Skip to main content

scirs2_stats/variational/
advi.rs

1//! Automatic Differentiation Variational Inference (ADVI)
2//!
3//! Implements ADVI (Kucukelbir et al. 2017) — transforms constrained parameters
4//! to unconstrained real space, then fits a Gaussian variational approximation
5//! by maximizing the ELBO via the reparameterization trick and Adam optimizer.
6//!
7//! Supports:
8//! - Mean-field approximation: `q(theta) = prod_i N(mu_i, sigma_i^2)`
9//! - Full-rank approximation: `q(theta) = N(mu, L L^T)` with Cholesky factor L
10//! - Automatic parameter transformations: log, logit, identity, bounded
11//! - Stochastic ELBO gradient estimation via reparameterization trick
12//! - Adam optimizer with configurable learning rate
13
14use crate::error::{StatsError, StatsResult};
15use scirs2_core::ndarray::{Array1, Array2};
16use std::f64::consts::PI;
17
18use super::{PosteriorResult, VariationalInference};
19
20// ============================================================================
21// Parameter Transforms
22// ============================================================================
23
/// How a single model parameter is mapped between its constrained support
/// and the unconstrained real line on which ADVI places its Gaussian.
#[derive(Clone, Debug)]
pub enum AdviTransform {
    /// No transformation: the parameter already ranges over the whole real line.
    Identity,
    /// `exp` / `ln` pair, for strictly positive parameters.
    Log,
    /// `sigmoid` / `logit` pair, for parameters in the open interval `(0, 1)`.
    Logit,
    /// Sigmoid rescaled onto the open interval `(lower, upper)`.
    Bounded {
        /// Lower end of the parameter's support.
        lower: f64,
        /// Upper end of the parameter's support.
        upper: f64,
    },
}
41
42impl AdviTransform {
43    /// Map from unconstrained to constrained space
44    pub fn forward(&self, eta: f64) -> f64 {
45        match self {
46            AdviTransform::Identity => eta,
47            AdviTransform::Log => eta.exp(),
48            AdviTransform::Logit => 1.0 / (1.0 + (-eta).exp()),
49            AdviTransform::Bounded { lower, upper } => {
50                let s = 1.0 / (1.0 + (-eta).exp());
51                lower + (upper - lower) * s
52            }
53        }
54    }
55
56    /// Map from constrained to unconstrained space
57    pub fn inverse(&self, theta: f64) -> StatsResult<f64> {
58        match self {
59            AdviTransform::Identity => Ok(theta),
60            AdviTransform::Log => {
61                if theta <= 0.0 {
62                    return Err(StatsError::InvalidArgument(format!(
63                        "Log transform requires positive value, got {}",
64                        theta
65                    )));
66                }
67                Ok(theta.ln())
68            }
69            AdviTransform::Logit => {
70                if theta <= 0.0 || theta >= 1.0 {
71                    return Err(StatsError::InvalidArgument(format!(
72                        "Logit transform requires value in (0, 1), got {}",
73                        theta
74                    )));
75                }
76                Ok((theta / (1.0 - theta)).ln())
77            }
78            AdviTransform::Bounded { lower, upper } => {
79                if theta <= *lower || theta >= *upper {
80                    return Err(StatsError::InvalidArgument(format!(
81                        "Bounded transform requires value in ({}, {}), got {}",
82                        lower, upper, theta
83                    )));
84                }
85                let s = (theta - lower) / (upper - lower);
86                Ok((s / (1.0 - s)).ln())
87            }
88        }
89    }
90
91    /// Log absolute Jacobian determinant of the forward transform
92    /// needed for change-of-variables correction in the ELBO
93    pub fn log_det_jacobian(&self, eta: f64) -> f64 {
94        match self {
95            AdviTransform::Identity => 0.0,
96            AdviTransform::Log => eta,
97            AdviTransform::Logit => {
98                // d/d(eta) sigmoid(eta) = sigmoid(eta) * (1 - sigmoid(eta))
99                // log|J| = log(sigmoid(eta)) + log(1 - sigmoid(eta))
100                //        = -softplus(-eta) + (-softplus(eta))
101                //        = eta - 2*softplus(eta)   [numerically stable]
102                let sp = softplus(eta);
103                eta - 2.0 * sp
104            }
105            AdviTransform::Bounded { lower, upper } => {
106                let log_range = (upper - lower).ln();
107                let sp = softplus(eta);
108                log_range + eta - 2.0 * sp
109            }
110        }
111    }
112
113    /// Gradient of the log-det-Jacobian w.r.t. unconstrained parameter eta
114    pub fn grad_log_det_jacobian(&self, eta: f64) -> f64 {
115        match self {
116            AdviTransform::Identity => 0.0,
117            AdviTransform::Log => 1.0,
118            AdviTransform::Logit | AdviTransform::Bounded { .. } => {
119                // d/d(eta) [eta - 2*softplus(eta)] = 1 - 2*sigmoid(eta)
120                let s = sigmoid(eta);
121                1.0 - 2.0 * s
122            }
123        }
124    }
125
126    /// Gradient of the forward transform w.r.t. eta (d theta / d eta)
127    pub fn forward_grad(&self, eta: f64) -> f64 {
128        match self {
129            AdviTransform::Identity => 1.0,
130            AdviTransform::Log => eta.exp(),
131            AdviTransform::Logit => {
132                let s = sigmoid(eta);
133                s * (1.0 - s)
134            }
135            AdviTransform::Bounded { lower, upper } => {
136                let s = sigmoid(eta);
137                (upper - lower) * s * (1.0 - s)
138            }
139        }
140    }
141}
142
/// Numerically stable softplus, `log(1 + exp(x))`.
///
/// Evaluated through the exact identity
/// `softplus(x) = max(x, 0) + log1p(exp(-|x|))`: the exponent is never
/// positive, so `exp` cannot overflow for large `x`, and `ln_1p` keeps full
/// relative precision when `exp(-|x|)` is tiny (large negative `x`), where the
/// naive `(1 + exp(x)).ln()` loses most of its significant digits.
/// NaN input yields NaN.
fn softplus(x: f64) -> f64 {
    x.max(0.0) + (-x.abs()).exp().ln_1p()
}
153
/// Logistic sigmoid, `1 / (1 + exp(-x))`.
///
/// The exponential is always taken of `-|x|`, so it never overflows; the
/// sign of `x` only decides which algebraically equivalent ratio is used.
fn sigmoid(x: f64) -> f64 {
    let tail = (-x.abs()).exp();
    if x >= 0.0 {
        1.0 / (1.0 + tail)
    } else {
        tail / (1.0 + tail)
    }
}
163
164// ============================================================================
165// Approximation Type
166// ============================================================================
167
/// Family of Gaussian distributions used as the variational approximation
/// over the unconstrained parameters `eta`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum AdviApproximation {
    /// Independent Gaussian per coordinate: `q(eta) = prod_i N(mu_i, sigma_i^2)`.
    MeanField,
    /// Correlated Gaussian `q(eta) = N(mu, L L^T)`, parameterised by a
    /// lower-triangular Cholesky factor `L`.
    FullRank,
}
176
177// ============================================================================
178// Adam Optimizer (self-contained for ADVI)
179// ============================================================================
180
181/// Adam optimizer state for ADVI
182#[derive(Debug, Clone)]
183struct AdviAdamState {
184    m: Array1<f64>,
185    v: Array1<f64>,
186    t: usize,
187    beta1: f64,
188    beta2: f64,
189    epsilon: f64,
190}
191
192impl AdviAdamState {
193    fn new(n_params: usize) -> Self {
194        Self {
195            m: Array1::zeros(n_params),
196            v: Array1::zeros(n_params),
197            t: 0,
198            beta1: 0.9,
199            beta2: 0.999,
200            epsilon: 1e-8,
201        }
202    }
203
204    /// Compute Adam update direction (to be scaled by learning rate)
205    fn update(&mut self, grad: &Array1<f64>) -> Array1<f64> {
206        self.t += 1;
207        let n = grad.len();
208        let mut direction = Array1::zeros(n);
209        for i in 0..n {
210            self.m[i] = self.beta1 * self.m[i] + (1.0 - self.beta1) * grad[i];
211            self.v[i] = self.beta2 * self.v[i] + (1.0 - self.beta2) * grad[i] * grad[i];
212            let m_hat = self.m[i] / (1.0 - self.beta1.powi(self.t as i32));
213            let v_hat = self.v[i] / (1.0 - self.beta2.powi(self.t as i32));
214            direction[i] = m_hat / (v_hat.sqrt() + self.epsilon);
215        }
216        direction
217    }
218}
219
220// ============================================================================
221// ADVI Configuration
222// ============================================================================
223
224/// Configuration for ADVI
225#[derive(Debug, Clone)]
226pub struct AdviConfig {
227    /// Type of variational approximation
228    pub approximation: AdviApproximation,
229    /// Parameter transforms (one per dimension); if empty, all Identity
230    pub transforms: Vec<AdviTransform>,
231    /// Number of Monte Carlo samples for ELBO gradient estimation
232    pub num_samples: usize,
233    /// Learning rate for Adam optimizer
234    pub learning_rate: f64,
235    /// Maximum number of optimization iterations
236    pub max_iterations: usize,
237    /// Convergence tolerance on ELBO change
238    pub tolerance: f64,
239    /// Random seed for reproducibility
240    pub seed: u64,
241    /// Window size for checking convergence (average over last N ELBOs)
242    pub convergence_window: usize,
243}
244
245impl Default for AdviConfig {
246    fn default() -> Self {
247        Self {
248            approximation: AdviApproximation::MeanField,
249            transforms: Vec::new(),
250            num_samples: 10,
251            learning_rate: 0.01,
252            max_iterations: 10000,
253            tolerance: 1e-4,
254            seed: 42,
255            convergence_window: 50,
256        }
257    }
258}
259
260// ============================================================================
261// ADVI Struct
262// ============================================================================
263
264/// Automatic Differentiation Variational Inference
265///
266/// ADVI automatically transforms constrained parameters to unconstrained space,
267/// then optimizes a Gaussian variational approximation using the ELBO.
268///
269/// # Example
270/// ```no_run
271/// use scirs2_stats::variational::{Advi, AdviConfig, AdviApproximation, AdviTransform};
272/// use scirs2_core::ndarray::Array1;
273///
274/// let config = AdviConfig {
275///     approximation: AdviApproximation::MeanField,
276///     transforms: vec![AdviTransform::Identity, AdviTransform::Log],
277///     num_samples: 10,
278///     learning_rate: 0.01,
279///     max_iterations: 1000,
280///     ..Default::default()
281/// };
282///
283/// let mut advi = Advi::new(config);
284/// ```
285#[derive(Debug, Clone)]
286pub struct Advi {
287    /// Configuration
288    pub config: AdviConfig,
289}
290
291impl Advi {
292    /// Create a new ADVI instance with the given configuration
293    pub fn new(config: AdviConfig) -> Self {
294        Self { config }
295    }
296
297    /// Generate quasi-random standard normal samples using Box-Muller with
298    /// golden-ratio quasi-random sequences for reproducibility
299    fn generate_epsilon(&self, dim: usize, seed: u64) -> Array1<f64> {
300        let golden = 1.618033988749895_f64;
301        let plastic = 1.324717957244746_f64;
302        Array1::from_shape_fn(dim, |i| {
303            let u1 = ((seed as f64 * golden + i as f64 * plastic) % 1.0).abs();
304            let u2 = ((seed as f64 * plastic + i as f64 * golden) % 1.0).abs();
305            let u1 = u1.max(1e-10).min(1.0 - 1e-10);
306            let u2 = u2.max(1e-10).min(1.0 - 1e-10);
307            let r = (-2.0 * u1.ln()).sqrt();
308            r * (2.0 * PI * u2).cos()
309        })
310    }
311
312    /// Get transform for dimension i, defaulting to Identity if not specified
313    fn get_transform(&self, i: usize) -> &AdviTransform {
314        if i < self.config.transforms.len() {
315            &self.config.transforms[i]
316        } else {
317            // Return a static reference to Identity
318            &AdviTransform::Identity
319        }
320    }
321
322    /// Transform unconstrained parameters to constrained space
323    fn transform_to_constrained(&self, eta: &Array1<f64>) -> Array1<f64> {
324        Array1::from_shape_fn(eta.len(), |i| self.get_transform(i).forward(eta[i]))
325    }
326
327    /// Compute sum of log-det-Jacobians for all transforms
328    fn total_log_det_jacobian(&self, eta: &Array1<f64>) -> f64 {
329        (0..eta.len())
330            .map(|i| self.get_transform(i).log_det_jacobian(eta[i]))
331            .sum()
332    }
333
334    /// Compute gradient of total log-det-Jacobian w.r.t. eta
335    fn grad_log_det_jacobian(&self, eta: &Array1<f64>) -> Array1<f64> {
336        Array1::from_shape_fn(eta.len(), |i| {
337            self.get_transform(i).grad_log_det_jacobian(eta[i])
338        })
339    }
340
341    /// Compute gradient of constrained theta w.r.t. unconstrained eta
342    fn forward_grad(&self, eta: &Array1<f64>) -> Array1<f64> {
343        Array1::from_shape_fn(eta.len(), |i| self.get_transform(i).forward_grad(eta[i]))
344    }
345
346    /// Fit mean-field ADVI
347    fn fit_mean_field<F>(&self, log_joint: F, dim: usize) -> StatsResult<PosteriorResult>
348    where
349        F: Fn(&Array1<f64>) -> StatsResult<(f64, Array1<f64>)>,
350    {
351        // Variational parameters: mu (dim) and log_sigma (dim)
352        let n_params = 2 * dim;
353        let mut mu = Array1::zeros(dim);
354        let mut log_sigma = Array1::zeros(dim); // sigma = 1 initially
355
356        let mut adam = AdviAdamState::new(n_params);
357        let mut elbo_history = Vec::with_capacity(self.config.max_iterations);
358        let mut converged = false;
359
360        for iter in 0..self.config.max_iterations {
361            let mut elbo_sum = 0.0;
362            let mut grad_mu_sum = Array1::zeros(dim);
363            let mut grad_log_sigma_sum = Array1::zeros(dim);
364
365            for s in 0..self.config.num_samples {
366                let seed = self
367                    .config
368                    .seed
369                    .wrapping_add(iter as u64 * 1000)
370                    .wrapping_add(s as u64);
371                let epsilon = self.generate_epsilon(dim, seed);
372
373                // Reparameterization: eta = mu + sigma * epsilon
374                let sigma = log_sigma.mapv(f64::exp);
375                let eta = &mu + &(&sigma * &epsilon);
376
377                // Transform to constrained space
378                let theta = self.transform_to_constrained(&eta);
379
380                // Evaluate log joint and gradient
381                let (log_p, grad_theta) = log_joint(&theta)?;
382
383                // Log-det-Jacobian correction
384                let ldj = self.total_log_det_jacobian(&eta);
385                let grad_ldj = self.grad_log_det_jacobian(&eta);
386
387                // Chain rule: d(log_p)/d(eta) = d(log_p)/d(theta) * d(theta)/d(eta)
388                let fwd_grad = self.forward_grad(&eta);
389                let grad_eta: Array1<f64> =
390                    Array1::from_shape_fn(dim, |i| grad_theta[i] * fwd_grad[i] + grad_ldj[i]);
391
392                // ELBO contribution: log p(x, theta) + log|det J|
393                let elbo_s = log_p + ldj;
394                elbo_sum += elbo_s;
395
396                // Gradients w.r.t. mu and log_sigma
397                // d(ELBO)/d(mu_i) = grad_eta_i
398                // d(ELBO)/d(log_sigma_i) = grad_eta_i * sigma_i * epsilon_i + 1 (entropy gradient)
399                for i in 0..dim {
400                    grad_mu_sum[i] += grad_eta[i];
401                    grad_log_sigma_sum[i] += grad_eta[i] * sigma[i] * epsilon[i];
402                }
403            }
404
405            let n_s = self.config.num_samples as f64;
406            elbo_sum /= n_s;
407            grad_mu_sum /= n_s;
408            grad_log_sigma_sum /= n_s;
409
410            // Add entropy gradient: d H[q] / d log_sigma_i = 1.0
411            // H[q] = sum_i (0.5 * (1 + log(2*pi)) + log_sigma_i)
412            for i in 0..dim {
413                grad_log_sigma_sum[i] += 1.0;
414            }
415
416            // Include entropy in ELBO
417            let entropy: f64 = (0..dim)
418                .map(|i| 0.5 * (1.0 + (2.0 * PI).ln()) + log_sigma[i])
419                .sum();
420            elbo_sum += entropy;
421
422            elbo_history.push(elbo_sum);
423
424            // Combine gradients into single vector for Adam
425            let mut full_grad = Array1::zeros(n_params);
426            for i in 0..dim {
427                full_grad[i] = grad_mu_sum[i];
428                full_grad[dim + i] = grad_log_sigma_sum[i];
429            }
430
431            // Adam update
432            let direction = adam.update(&full_grad);
433            let lr = self.config.learning_rate;
434            for i in 0..dim {
435                mu[i] += lr * direction[i];
436                log_sigma[i] += lr * direction[dim + i];
437                // Clip log_sigma to prevent numerical issues
438                log_sigma[i] = log_sigma[i].max(-10.0).min(10.0);
439            }
440
441            // Check convergence
442            if elbo_history.len() >= self.config.convergence_window {
443                let n = elbo_history.len();
444                let w = self.config.convergence_window;
445                let recent_avg: f64 =
446                    elbo_history[n - w / 2..n].iter().sum::<f64>() / (w / 2) as f64;
447                let earlier_avg: f64 =
448                    elbo_history[n - w..n - w / 2].iter().sum::<f64>() / (w / 2) as f64;
449                if (recent_avg - earlier_avg).abs() < self.config.tolerance {
450                    converged = true;
451                    break;
452                }
453            }
454        }
455
456        // Compute posterior in constrained space
457        let sigma = log_sigma.mapv(f64::exp);
458        let constrained_means = self.transform_to_constrained(&mu);
459
460        // Approximate constrained std devs via delta method:
461        // Var(theta) approx (d theta/d eta)^2 * Var(eta)
462        let fwd_grad = self.forward_grad(&mu);
463        let constrained_stds = Array1::from_shape_fn(dim, |i| (fwd_grad[i] * sigma[i]).abs());
464
465        Ok(PosteriorResult {
466            means: constrained_means,
467            std_devs: constrained_stds,
468            elbo_history: elbo_history.clone(),
469            iterations: elbo_history.len(),
470            converged,
471            samples: None,
472        })
473    }
474
475    /// Fit full-rank ADVI
476    fn fit_full_rank<F>(&self, log_joint: F, dim: usize) -> StatsResult<PosteriorResult>
477    where
478        F: Fn(&Array1<f64>) -> StatsResult<(f64, Array1<f64>)>,
479    {
480        // Parameters: mu (dim) + lower-triangular L (dim*(dim+1)/2)
481        let n_tril = dim * (dim + 1) / 2;
482        let n_params = dim + n_tril;
483        let mut mu = Array1::zeros(dim);
484        // Initialize L as identity (store lower triangular entries)
485        let mut l_entries = Array1::zeros(n_tril);
486        {
487            let mut idx = 0;
488            for row in 0..dim {
489                for col in 0..=row {
490                    if row == col {
491                        l_entries[idx] = 1.0; // diagonal = 1
492                    }
493                    idx += 1;
494                }
495            }
496        }
497
498        let mut adam = AdviAdamState::new(n_params);
499        let mut elbo_history = Vec::with_capacity(self.config.max_iterations);
500        let mut converged = false;
501
502        for iter in 0..self.config.max_iterations {
503            // Reconstruct L matrix from entries
504            let l_mat = tril_to_matrix(dim, &l_entries);
505
506            let mut elbo_sum = 0.0;
507            let mut grad_mu_sum = Array1::zeros(dim);
508            let mut grad_l_sum = Array1::zeros(n_tril);
509
510            for s in 0..self.config.num_samples {
511                let seed = self
512                    .config
513                    .seed
514                    .wrapping_add(iter as u64 * 1000)
515                    .wrapping_add(s as u64);
516                let epsilon = self.generate_epsilon(dim, seed);
517
518                // Reparameterization: eta = mu + L * epsilon
519                let l_eps = l_mat.dot(&epsilon);
520                let eta = &mu + &l_eps;
521
522                // Transform to constrained space
523                let theta = self.transform_to_constrained(&eta);
524
525                // Evaluate log joint and gradient
526                let (log_p, grad_theta) = log_joint(&theta)?;
527
528                // Log-det-Jacobian correction
529                let ldj = self.total_log_det_jacobian(&eta);
530                let grad_ldj = self.grad_log_det_jacobian(&eta);
531
532                // Chain rule: d(log_p)/d(eta) = d(log_p)/d(theta) * d(theta)/d(eta)
533                let fwd_grad = self.forward_grad(&eta);
534                let grad_eta: Array1<f64> =
535                    Array1::from_shape_fn(dim, |i| grad_theta[i] * fwd_grad[i] + grad_ldj[i]);
536
537                let elbo_s = log_p + ldj;
538                elbo_sum += elbo_s;
539
540                // Gradients w.r.t. mu
541                for i in 0..dim {
542                    grad_mu_sum[i] += grad_eta[i];
543                }
544
545                // Gradients w.r.t. L entries
546                // d(eta)/d(L_{ij}) = epsilon_j (when i is the row)
547                let mut idx = 0;
548                for row in 0..dim {
549                    for col in 0..=row {
550                        grad_l_sum[idx] += grad_eta[row] * epsilon[col];
551                        idx += 1;
552                    }
553                }
554            }
555
556            let n_s = self.config.num_samples as f64;
557            elbo_sum /= n_s;
558            grad_mu_sum /= n_s;
559            grad_l_sum /= n_s;
560
561            // Entropy of full-rank Gaussian:
562            // H[q] = 0.5 * d * (1 + log(2*pi)) + sum_i log|L_ii|
563            let mut entropy = 0.5 * dim as f64 * (1.0 + (2.0 * PI).ln());
564            {
565                let mut idx = 0;
566                for row in 0..dim {
567                    for col in 0..=row {
568                        if row == col {
569                            entropy += l_entries[idx].abs().max(1e-15).ln();
570                            // Gradient of entropy w.r.t. L_ii = 1/L_ii
571                            let l_ii = l_entries[idx];
572                            if l_ii.abs() > 1e-15 {
573                                grad_l_sum[idx] += 1.0 / l_ii;
574                            }
575                        }
576                        idx += 1;
577                    }
578                }
579            }
580            elbo_sum += entropy;
581            elbo_history.push(elbo_sum);
582
583            // Combine gradients
584            let mut full_grad = Array1::zeros(n_params);
585            for i in 0..dim {
586                full_grad[i] = grad_mu_sum[i];
587            }
588            for i in 0..n_tril {
589                full_grad[dim + i] = grad_l_sum[i];
590            }
591
592            // Adam update
593            let direction = adam.update(&full_grad);
594            let lr = self.config.learning_rate;
595            for i in 0..dim {
596                mu[i] += lr * direction[i];
597            }
598            for i in 0..n_tril {
599                l_entries[i] += lr * direction[dim + i];
600            }
601
602            // Ensure diagonal of L stays positive (for valid Cholesky)
603            {
604                let mut idx = 0;
605                for row in 0..dim {
606                    for col in 0..=row {
607                        if row == col {
608                            l_entries[idx] = l_entries[idx].abs().max(1e-6);
609                        }
610                        // Clip entries for stability
611                        l_entries[idx] = l_entries[idx].max(-10.0).min(10.0);
612                        idx += 1;
613                    }
614                }
615            }
616
617            // Check convergence
618            if elbo_history.len() >= self.config.convergence_window {
619                let n = elbo_history.len();
620                let w = self.config.convergence_window;
621                let recent_avg: f64 =
622                    elbo_history[n - w / 2..n].iter().sum::<f64>() / (w / 2) as f64;
623                let earlier_avg: f64 =
624                    elbo_history[n - w..n - w / 2].iter().sum::<f64>() / (w / 2) as f64;
625                if (recent_avg - earlier_avg).abs() < self.config.tolerance {
626                    converged = true;
627                    break;
628                }
629            }
630        }
631
632        // Compute posterior statistics
633        let l_mat = tril_to_matrix(dim, &l_entries);
634        let constrained_means = self.transform_to_constrained(&mu);
635
636        // Covariance in unconstrained space: Sigma = L L^T
637        let cov = l_mat.dot(&l_mat.t());
638
639        // Delta-method std devs in constrained space
640        let fwd_grad = self.forward_grad(&mu);
641        let constrained_stds =
642            Array1::from_shape_fn(dim, |i| (fwd_grad[i] * fwd_grad[i] * cov[[i, i]]).sqrt());
643
644        Ok(PosteriorResult {
645            means: constrained_means,
646            std_devs: constrained_stds,
647            elbo_history: elbo_history.clone(),
648            iterations: elbo_history.len(),
649            converged,
650            samples: None,
651        })
652    }
653}
654
655impl VariationalInference for Advi {
656    fn fit<F>(&mut self, log_joint: F, dim: usize) -> StatsResult<PosteriorResult>
657    where
658        F: Fn(&Array1<f64>) -> StatsResult<(f64, Array1<f64>)>,
659    {
660        if dim == 0 {
661            return Err(StatsError::InvalidArgument(
662                "Dimension must be at least 1".to_string(),
663            ));
664        }
665        if self.config.num_samples == 0 {
666            return Err(StatsError::InvalidArgument(
667                "num_samples must be at least 1".to_string(),
668            ));
669        }
670        if self.config.learning_rate <= 0.0 {
671            return Err(StatsError::InvalidArgument(
672                "learning_rate must be positive".to_string(),
673            ));
674        }
675
676        match self.config.approximation {
677            AdviApproximation::MeanField => self.fit_mean_field(log_joint, dim),
678            AdviApproximation::FullRank => self.fit_full_rank(log_joint, dim),
679        }
680    }
681}
682
683// ============================================================================
684// Helper: lower-triangular entries <-> matrix
685// ============================================================================
686
687/// Reconstruct a dim x dim lower-triangular matrix from flat entries
688fn tril_to_matrix(dim: usize, entries: &Array1<f64>) -> Array2<f64> {
689    let mut mat = Array2::zeros((dim, dim));
690    let mut idx = 0;
691    for row in 0..dim {
692        for col in 0..=row {
693            mat[[row, col]] = entries[idx];
694            idx += 1;
695        }
696    }
697    mat
698}
699
700// ============================================================================
701// Tests
702// ============================================================================
703
704#[cfg(test)]
705mod tests {
706    use super::*;
707
708    /// Test: ADVI recovers the posterior mean of a 1D Gaussian with known
709    /// conjugate prior: N(mu | mu0, sigma0^2) * N(data | mu, sigma^2)
710    #[test]
711    fn test_advi_gaussian_posterior_recovery() {
712        // Prior: N(0, 1), Likelihood: data ~ N(mu, 1), observed data mean = 3.0, n = 10
713        // Posterior: N(mu | 30/11, 1/11)
714        let data_mean = 3.0_f64;
715        let n_data = 10.0_f64;
716        let prior_mean = 0.0_f64;
717        let prior_var = 1.0_f64;
718        let lik_var = 1.0_f64;
719
720        let config = AdviConfig {
721            approximation: AdviApproximation::MeanField,
722            transforms: vec![AdviTransform::Identity],
723            num_samples: 20,
724            learning_rate: 0.05,
725            max_iterations: 3000,
726            tolerance: 1e-5,
727            seed: 123,
728            convergence_window: 100,
729        };
730
731        let mut advi = Advi::new(config);
732        let result = advi
733            .fit(
734                move |theta: &Array1<f64>| {
735                    let mu = theta[0];
736                    // Log prior: N(mu | 0, 1)
737                    let log_prior = -0.5 * (mu - prior_mean).powi(2) / prior_var;
738                    // Log likelihood: sum of N(x_i | mu, 1) = -n/2 * (mu - data_mean)^2
739                    let log_lik = -n_data / 2.0 * (mu - data_mean).powi(2) / lik_var;
740                    let log_p = log_prior + log_lik;
741                    // Gradient
742                    let grad_prior = -(mu - prior_mean) / prior_var;
743                    let grad_lik = -n_data * (mu - data_mean) / lik_var;
744                    let grad = Array1::from_vec(vec![grad_prior + grad_lik]);
745                    Ok((log_p, grad))
746                },
747                1,
748            )
749            .expect("ADVI should not fail");
750
751        let expected_mean = (n_data * data_mean / lik_var + prior_mean / prior_var)
752            / (n_data / lik_var + 1.0 / prior_var);
753        let expected_std = (1.0 / (n_data / lik_var + 1.0 / prior_var)).sqrt();
754
755        assert!(
756            (result.means[0] - expected_mean).abs() < 0.3,
757            "Mean should be close to {}, got {}",
758            expected_mean,
759            result.means[0]
760        );
761        assert!(
762            (result.std_devs[0] - expected_std).abs() < 0.2,
763            "Std should be close to {}, got {}",
764            expected_std,
765            result.std_devs[0]
766        );
767    }
768
769    /// Test: ELBO increases (or at least does not decrease on average) over iterations
770    #[test]
771    fn test_advi_elbo_increases() {
772        let config = AdviConfig {
773            approximation: AdviApproximation::MeanField,
774            transforms: vec![AdviTransform::Identity, AdviTransform::Identity],
775            num_samples: 15,
776            learning_rate: 0.02,
777            max_iterations: 500,
778            tolerance: 1e-6,
779            seed: 77,
780            convergence_window: 50,
781        };
782
783        let mut advi = Advi::new(config);
784        let result = advi
785            .fit(
786                |theta: &Array1<f64>| {
787                    // Simple 2D Gaussian target: N([1, 2], I)
788                    let diff0 = theta[0] - 1.0;
789                    let diff1 = theta[1] - 2.0;
790                    let log_p = -0.5 * (diff0 * diff0 + diff1 * diff1);
791                    let grad = Array1::from_vec(vec![-diff0, -diff1]);
792                    Ok((log_p, grad))
793                },
794                2,
795            )
796            .expect("ADVI should succeed");
797
798        // Check that late-stage ELBO is higher than early-stage
799        let n = result.elbo_history.len();
800        assert!(n > 100, "Should run at least 100 iterations");
801        let early_avg: f64 = result.elbo_history[..50].iter().sum::<f64>() / 50.0;
802        let late_avg: f64 = result.elbo_history[n - 50..].iter().sum::<f64>() / 50.0;
803        assert!(
804            late_avg > early_avg - 1.0,
805            "Late ELBO ({}) should be higher than early ({})",
806            late_avg,
807            early_avg
808        );
809    }
810
811    /// Test: Mean-field vs full-rank comparison
812    /// Full-rank should achieve at least as good an ELBO as mean-field
813    /// on a correlated target
814    #[test]
815    fn test_advi_mean_field_vs_full_rank() {
816        // Correlated 2D Gaussian target: rho = 0.8
817        let rho = 0.8_f64;
818        let log_joint = move |theta: &Array1<f64>| {
819            let x = theta[0];
820            let y = theta[1];
821            let det = 1.0 - rho * rho;
822            let log_p =
823                -0.5 / det * (x * x - 2.0 * rho * x * y + y * y) - 0.5 * (2.0 * PI * det).ln();
824            let gx = -1.0 / det * (x - rho * y);
825            let gy = -1.0 / det * (y - rho * x);
826            Ok((log_p, Array1::from_vec(vec![gx, gy])))
827        };
828
829        // Mean-field
830        let mf_config = AdviConfig {
831            approximation: AdviApproximation::MeanField,
832            num_samples: 20,
833            learning_rate: 0.02,
834            max_iterations: 2000,
835            tolerance: 1e-5,
836            seed: 42,
837            convergence_window: 100,
838            ..Default::default()
839        };
840        let mut mf_advi = Advi::new(mf_config);
841        let mf_result = mf_advi.fit(log_joint, 2).expect("MF should succeed");
842
843        // Full-rank
844        let fr_config = AdviConfig {
845            approximation: AdviApproximation::FullRank,
846            num_samples: 20,
847            learning_rate: 0.02,
848            max_iterations: 2000,
849            tolerance: 1e-5,
850            seed: 42,
851            convergence_window: 100,
852            ..Default::default()
853        };
854        let mut fr_advi = Advi::new(fr_config);
855        let fr_result = fr_advi.fit(log_joint, 2).expect("FR should succeed");
856
857        let mf_final_elbo = mf_result
858            .elbo_history
859            .last()
860            .copied()
861            .unwrap_or(f64::NEG_INFINITY);
862        let fr_final_elbo = fr_result
863            .elbo_history
864            .last()
865            .copied()
866            .unwrap_or(f64::NEG_INFINITY);
867
868        // Full-rank should do at least as well (with some tolerance for stochasticity)
869        assert!(
870            fr_final_elbo > mf_final_elbo - 1.0,
871            "Full-rank ELBO ({}) should be >= mean-field ELBO ({}) minus tolerance",
872            fr_final_elbo,
873            mf_final_elbo
874        );
875    }
876
877    /// Test: ADVI with log transform recovers a positive parameter
878    #[test]
879    fn test_advi_log_transform() {
880        // Target: Gamma(3, 1) = log-concave for shape >= 1
881        // mode = (shape - 1) / rate = 2.0
882        let config = AdviConfig {
883            approximation: AdviApproximation::MeanField,
884            transforms: vec![AdviTransform::Log],
885            num_samples: 20,
886            learning_rate: 0.01,
887            max_iterations: 3000,
888            tolerance: 1e-5,
889            seed: 55,
890            convergence_window: 100,
891        };
892
893        let mut advi = Advi::new(config);
894        let result = advi
895            .fit(
896                |theta: &Array1<f64>| {
897                    let x = theta[0];
898                    if x <= 0.0 {
899                        return Ok((f64::NEG_INFINITY, Array1::zeros(1)));
900                    }
901                    // Gamma(3, 1): log p(x) = (3-1)*ln(x) - x - ln(Gamma(3))
902                    let log_p = 2.0 * x.ln() - x - (2.0_f64).ln(); // Gamma(3) = 2! = 2
903                    let grad = Array1::from_vec(vec![2.0 / x - 1.0]);
904                    Ok((log_p, grad))
905                },
906                1,
907            )
908            .expect("ADVI with log transform should succeed");
909
910        // Gamma(3,1) mean = 3, mode = 2
911        assert!(
912            result.means[0] > 0.0,
913            "Mean should be positive with log transform"
914        );
915        assert!(
916            (result.means[0] - 3.0).abs() < 1.5,
917            "Mean should be near 3 (Gamma(3,1) mean), got {}",
918            result.means[0]
919        );
920    }
921
922    /// Test: dimension validation
923    #[test]
924    fn test_advi_zero_dim_error() {
925        let mut advi = Advi::new(AdviConfig::default());
926        let result = advi.fit(|_theta: &Array1<f64>| Ok((0.0, Array1::zeros(0))), 0);
927        assert!(result.is_err());
928    }
929
930    /// Test: transform forward-inverse roundtrip
931    #[test]
932    fn test_transform_roundtrip() {
933        let transforms = vec![
934            AdviTransform::Identity,
935            AdviTransform::Log,
936            AdviTransform::Logit,
937            AdviTransform::Bounded {
938                lower: -2.0,
939                upper: 5.0,
940            },
941        ];
942        let test_vals = vec![1.5, 2.0, 0.3, 1.0];
943
944        for (t, v) in transforms.iter().zip(test_vals.iter()) {
945            let eta = t.inverse(*v).expect("inverse should succeed");
946            let recovered = t.forward(eta);
947            assert!(
948                (recovered - v).abs() < 1e-10,
949                "Roundtrip failed for {:?}: {} -> {} -> {}",
950                t,
951                v,
952                eta,
953                recovered
954            );
955        }
956    }
957
958    /// Test: log-det-Jacobian is nonzero for non-identity transforms
959    #[test]
960    fn test_log_det_jacobian_nonzero() {
961        let transforms = vec![
962            AdviTransform::Log,
963            AdviTransform::Logit,
964            AdviTransform::Bounded {
965                lower: 0.0,
966                upper: 10.0,
967            },
968        ];
969        for t in &transforms {
970            let ldj = t.log_det_jacobian(0.5);
971            assert!(
972                ldj.is_finite(),
973                "Log-det-Jacobian should be finite for {:?}",
974                t
975            );
976        }
977    }
978}