Skip to main content

scirs2_stats/bayesian/variational/
svi.rs

1//! Stochastic Variational Inference (SVI)
2//!
3//! This module implements scalable variational inference methods that use
4//! stochastic optimization for large datasets. Key features:
5//!
6//! - Mini-batch ELBO estimation
7//! - Natural gradient updates
8//! - Adam-like learning rate scheduling
9//! - Support for both mean-field and full-rank Gaussian approximations
10
11use crate::error::{StatsError, StatsResult as Result};
12use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
13use scirs2_core::validation::*;
14use std::f64::consts::PI;
15
16use super::{digamma, lgamma, FullRankGaussian, MeanFieldGaussian, VariationalDiagnostics};
17
18// ============================================================================
19// Learning Rate Schedules
20// ============================================================================
21
/// Learning rate schedule for SVI optimization
#[derive(Debug, Clone)]
pub enum LearningRateSchedule {
    /// Constant learning rate
    Constant {
        /// The fixed learning rate
        lr: f64,
    },
    /// Robbins-Monro schedule: lr_t = lr_0 / (1 + decay * t)
    RobbinsMonro {
        /// Initial learning rate
        lr0: f64,
        /// Decay factor
        decay: f64,
    },
    /// Exponential decay: lr_t = lr_0 * gamma^t
    ExponentialDecay {
        /// Initial learning rate
        lr0: f64,
        /// Decay multiplier per step
        gamma: f64,
    },
    /// Adam-like adaptive learning rate
    Adam {
        /// Base learning rate
        lr: f64,
        /// First moment decay (beta1)
        beta1: f64,
        /// Second moment decay (beta2)
        beta2: f64,
        /// Numerical stability constant
        epsilon: f64,
    },
}

impl LearningRateSchedule {
    /// Get the learning rate at (0-based) iteration `t`.
    ///
    /// For [`LearningRateSchedule::Adam`] only the base rate is returned; the
    /// per-parameter adaptation is performed by the optimizer state.
    pub fn get_lr(&self, t: usize) -> f64 {
        match self {
            LearningRateSchedule::Constant { lr } => *lr,
            LearningRateSchedule::RobbinsMonro { lr0, decay } => lr0 / (1.0 + decay * t as f64),
            LearningRateSchedule::ExponentialDecay { lr0, gamma } => {
                // `t as i32` wraps to a negative exponent for t > i32::MAX, which
                // would turn the decay into explosive growth; use `powf` there.
                let factor = if t <= i32::MAX as usize {
                    gamma.powi(t as i32)
                } else {
                    gamma.powf(t as f64)
                };
                lr0 * factor
            }
            LearningRateSchedule::Adam { lr, .. } => {
                // Base learning rate; actual Adam adjustment is done in the optimizer
                *lr
            }
        }
    }

    /// Create a default Adam schedule (lr = 0.01, beta1 = 0.9, beta2 = 0.999, epsilon = 1e-8).
    pub fn default_adam() -> Self {
        LearningRateSchedule::Adam {
            lr: 0.01,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
        }
    }

    /// Create a default Robbins-Monro schedule (lr0 = 0.1, decay = 0.01).
    pub fn default_robbins_monro() -> Self {
        LearningRateSchedule::RobbinsMonro {
            lr0: 0.1,
            decay: 0.01,
        }
    }
}
89
90// ============================================================================
91// Adam Optimizer State
92// ============================================================================
93
94/// Adam optimizer state for adaptive learning rates
95#[derive(Debug, Clone)]
96pub struct AdamState {
97    /// First moment estimates
98    pub m: Array1<f64>,
99    /// Second moment estimates
100    pub v: Array1<f64>,
101    /// Beta1 parameter (first moment decay)
102    pub beta1: f64,
103    /// Beta2 parameter (second moment decay)
104    pub beta2: f64,
105    /// Numerical stability epsilon
106    pub epsilon: f64,
107    /// Base learning rate
108    pub lr: f64,
109    /// Current time step
110    pub t: usize,
111}
112
113impl AdamState {
114    /// Create a new Adam state for parameters of given dimension
115    pub fn new(dim: usize, lr: f64, beta1: f64, beta2: f64, epsilon: f64) -> Result<Self> {
116        check_positive(dim, "dim")?;
117        check_positive(lr, "lr")?;
118        check_positive(epsilon, "epsilon")?;
119
120        Ok(Self {
121            m: Array1::zeros(dim),
122            v: Array1::zeros(dim),
123            beta1,
124            beta2,
125            epsilon,
126            lr,
127            t: 0,
128        })
129    }
130
131    /// Compute Adam update for a given gradient
132    pub fn compute_update(&mut self, gradient: &Array1<f64>) -> Result<Array1<f64>> {
133        if gradient.len() != self.m.len() {
134            return Err(StatsError::DimensionMismatch(format!(
135                "gradient length ({}) must match state dimension ({})",
136                gradient.len(),
137                self.m.len()
138            )));
139        }
140
141        self.t += 1;
142
143        // Update biased first moment estimate
144        self.m = &self.m * self.beta1 + gradient * (1.0 - self.beta1);
145
146        // Update biased second raw moment estimate
147        self.v = &self.v * self.beta2 + &gradient.mapv(|g| g * g) * (1.0 - self.beta2);
148
149        // Compute bias-corrected first moment estimate
150        let m_hat = &self.m / (1.0 - self.beta1.powi(self.t as i32));
151
152        // Compute bias-corrected second raw moment estimate
153        let v_hat = &self.v / (1.0 - self.beta2.powi(self.t as i32));
154
155        // Compute update
156        let update = &m_hat / &v_hat.mapv(|vi| vi.sqrt() + self.epsilon) * self.lr;
157
158        Ok(update)
159    }
160
161    /// Reset the optimizer state
162    pub fn reset(&mut self) {
163        self.m.fill(0.0);
164        self.v.fill(0.0);
165        self.t = 0;
166    }
167}
168
169// ============================================================================
170// Natural Gradient Computations
171// ============================================================================
172
173/// Natural gradient parameters for exponential family distributions
174#[derive(Debug, Clone)]
175pub struct NaturalGradientParams {
176    /// Natural parameters (eta) for the variational distribution
177    pub eta: Array1<f64>,
178    /// Fisher information matrix (or its approximation)
179    /// Stored as diagonal for mean-field case
180    pub fisher_diag: Array1<f64>,
181}
182
183impl NaturalGradientParams {
184    /// Create natural gradient parameters for a mean-field Gaussian
185    ///
186    /// For a Gaussian q(z; mu, sigma^2):
187    ///   eta_1 = mu / sigma^2   (natural parameter 1)
188    ///   eta_2 = -1 / (2*sigma^2) (natural parameter 2)
189    ///
190    /// Fisher information is 1/sigma^2 for the mean parameter
191    /// and 2/sigma^4 for the variance parameter
192    pub fn from_mean_field(mf: &MeanFieldGaussian) -> Self {
193        let dim = mf.dim;
194        let stds = mf.stds();
195        let vars = mf.variances();
196
197        // Natural parameters: [mu/sigma^2, -1/(2*sigma^2)]
198        let mut eta = Array1::zeros(2 * dim);
199        let mut fisher_diag = Array1::zeros(2 * dim);
200
201        for i in 0..dim {
202            // eta_1 = mu / sigma^2
203            eta[i] = mf.means[i] / vars[i];
204            // eta_2 = -1 / (2*sigma^2)
205            eta[dim + i] = -1.0 / (2.0 * vars[i]);
206
207            // Fisher diagonal
208            fisher_diag[i] = 1.0 / vars[i]; // For mean
209            fisher_diag[dim + i] = 2.0 / (stds[i].powi(4)); // For variance
210        }
211
212        Self { eta, fisher_diag }
213    }
214
215    /// Convert natural parameters back to mean/std parameterization
216    pub fn to_mean_field(&self) -> Result<MeanFieldGaussian> {
217        let dim = self.eta.len() / 2;
218        if dim == 0 {
219            return Err(StatsError::InvalidArgument(
220                "Natural parameters must have positive dimension".to_string(),
221            ));
222        }
223
224        let mut means = Array1::zeros(dim);
225        let mut log_stds = Array1::zeros(dim);
226
227        for i in 0..dim {
228            let eta2 = self.eta[dim + i];
229            if eta2 >= 0.0 {
230                return Err(StatsError::InvalidArgument(format!(
231                    "eta_2[{}] = {} must be negative for valid Gaussian",
232                    i, eta2
233                )));
234            }
235            let var = -1.0 / (2.0 * eta2);
236            let mean = self.eta[i] * var;
237            means[i] = mean;
238            log_stds[i] = 0.5 * var.ln();
239        }
240
241        MeanFieldGaussian::from_params(means, log_stds)
242    }
243
244    /// Compute natural gradient update: update = Fisher^{-1} * euclidean_grad
245    /// For diagonal Fisher, this is element-wise division
246    pub fn natural_gradient_update(&self, euclidean_grad: &Array1<f64>) -> Result<Array1<f64>> {
247        if euclidean_grad.len() != self.fisher_diag.len() {
248            return Err(StatsError::DimensionMismatch(format!(
249                "gradient length ({}) must match parameter dimension ({})",
250                euclidean_grad.len(),
251                self.fisher_diag.len()
252            )));
253        }
254
255        let mut nat_grad = Array1::zeros(euclidean_grad.len());
256        for i in 0..euclidean_grad.len() {
257            if self.fisher_diag[i].abs() < 1e-15 {
258                nat_grad[i] = 0.0; // Avoid division by zero
259            } else {
260                nat_grad[i] = euclidean_grad[i] / self.fisher_diag[i];
261            }
262        }
263
264        Ok(nat_grad)
265    }
266}
267
268// ============================================================================
269// SVI Configuration
270// ============================================================================
271
272/// Configuration for Stochastic Variational Inference
273#[derive(Debug, Clone)]
274pub struct SviConfig {
275    /// Maximum number of iterations
276    pub max_iter: usize,
277    /// Mini-batch size (number of data points per batch)
278    pub batch_size: usize,
279    /// Learning rate schedule
280    pub lr_schedule: LearningRateSchedule,
281    /// Convergence tolerance (on ELBO)
282    pub tol: f64,
283    /// Number of Monte Carlo samples for ELBO estimation
284    pub n_mc_samples: usize,
285    /// Whether to use natural gradients
286    pub use_natural_gradient: bool,
287    /// How often to compute full ELBO for diagnostics (0 = never)
288    pub diagnostic_interval: usize,
289    /// Gradient clipping threshold (0 = no clipping)
290    pub grad_clip: f64,
291    /// Seed for reproducibility (used for batch selection)
292    pub seed: u64,
293}
294
295impl Default for SviConfig {
296    fn default() -> Self {
297        Self {
298            max_iter: 1000,
299            batch_size: 32,
300            lr_schedule: LearningRateSchedule::default_adam(),
301            tol: 1e-4,
302            n_mc_samples: 1,
303            use_natural_gradient: false,
304            diagnostic_interval: 50,
305            grad_clip: 10.0,
306            seed: 42,
307        }
308    }
309}
310
311// ============================================================================
312// Stochastic Variational Inference
313// ============================================================================
314
315/// Stochastic Variational Inference (SVI)
316///
317/// Implements scalable variational inference using stochastic gradient
318/// ascent on the ELBO. Supports:
319/// - Mini-batch ELBO estimation for large datasets
320/// - Natural gradient updates for faster convergence
321/// - Adam adaptive learning rates
322/// - Gradient clipping for stability
323///
324/// The model assumes a mean-field Gaussian posterior approximation
325/// q(z) = prod_i N(z_i; mu_i, sigma_i^2)
326///
327/// The user provides a log joint density function log p(z, x_batch) that
328/// takes the latent variables z and a mini-batch of data.
329#[derive(Debug, Clone)]
330pub struct StochasticVI {
331    /// Variational distribution (mean-field Gaussian)
332    pub variational: MeanFieldGaussian,
333    /// Configuration
334    pub config: SviConfig,
335    /// Diagnostics
336    pub diagnostics: VariationalDiagnostics,
337    /// Adam optimizer state (if using Adam)
338    adam_state: Option<AdamState>,
339}
340
341impl StochasticVI {
342    /// Create a new SVI instance
343    pub fn new(dim: usize, config: SviConfig) -> Result<Self> {
344        check_positive(dim, "dim")?;
345
346        let variational = MeanFieldGaussian::new(dim)?;
347
348        let adam_state = if let LearningRateSchedule::Adam {
349            lr,
350            beta1,
351            beta2,
352            epsilon,
353        } = &config.lr_schedule
354        {
355            Some(AdamState::new(2 * dim, *lr, *beta1, *beta2, *epsilon)?)
356        } else {
357            None
358        };
359
360        Ok(Self {
361            variational,
362            config,
363            diagnostics: VariationalDiagnostics::new(),
364            adam_state,
365        })
366    }
367
368    /// Run SVI optimization with a log joint density function
369    ///
370    /// # Arguments
371    /// * `data` - Full dataset (rows are observations)
372    /// * `log_joint` - Function computing log p(z, x_batch) given latent variables
373    ///   and a batch of data. Returns (log_prob, gradient_wrt_z).
374    /// * `n_total` - Total number of data points (for scaling mini-batch ELBO)
375    ///
376    /// # Returns
377    /// * The optimized variational distribution and diagnostics
378    pub fn fit<F>(
379        &mut self,
380        data: ArrayView2<f64>,
381        log_joint: F,
382        n_total: usize,
383    ) -> Result<SviResult>
384    where
385        F: Fn(&Array1<f64>, ArrayView2<f64>) -> Result<(f64, Array1<f64>)>,
386    {
387        checkarray_finite(&data, "data")?;
388        check_positive(n_total, "n_total")?;
389
390        let (n_data, _) = data.dim();
391        let batch_size = self.config.batch_size.min(n_data);
392        let scale_factor = n_total as f64 / batch_size as f64;
393
394        // Simple deterministic batch cycling with seed-based offset
395        let offset = (self.config.seed % n_data as u64) as usize;
396
397        for iter in 0..self.config.max_iter {
398            // Select mini-batch (deterministic cycling with offset)
399            let batch_start = (offset + iter * batch_size) % n_data;
400            let batch_end = (batch_start + batch_size).min(n_data);
401            let actual_batch_size = batch_end - batch_start;
402
403            let batch = data.slice(scirs2_core::ndarray::s![batch_start..batch_end, ..]);
404
405            // Estimate ELBO gradient using reparameterization trick
406            let (elbo_estimate, grad) = self.estimate_elbo_gradient(
407                batch,
408                &log_joint,
409                scale_factor * (actual_batch_size as f64 / batch_size as f64),
410            )?;
411
412            // Record diagnostics
413            self.diagnostics.record_elbo(elbo_estimate);
414            let grad_norm = grad.dot(&grad).sqrt();
415            self.diagnostics.record_gradient_norm(grad_norm);
416
417            // Apply gradient clipping if configured
418            let clipped_grad = if self.config.grad_clip > 0.0 && grad_norm > self.config.grad_clip {
419                &grad * (self.config.grad_clip / grad_norm)
420            } else {
421                grad
422            };
423
424            // Compute update (natural gradient or Euclidean)
425            let update = if self.config.use_natural_gradient {
426                let nat_params = NaturalGradientParams::from_mean_field(&self.variational);
427                nat_params.natural_gradient_update(&clipped_grad)?
428            } else {
429                clipped_grad
430            };
431
432            // Apply learning rate and update parameters
433            let lr = self.config.lr_schedule.get_lr(iter);
434            let current_params = self.variational.get_params();
435
436            let new_params = if let Some(ref mut adam) = self.adam_state {
437                let adam_update = adam.compute_update(&update)?;
438                &current_params + &adam_update
439            } else {
440                &current_params + &(&update * lr)
441            };
442
443            // Track parameter change
444            let param_change = (&new_params - &current_params).mapv(|x| x * x).sum().sqrt();
445            self.diagnostics.record_param_change(param_change);
446
447            self.variational.set_params(&new_params)?;
448
449            // Check convergence
450            if iter > 10 && self.diagnostics.check_elbo_convergence(self.config.tol) {
451                self.diagnostics.converged = true;
452                break;
453            }
454        }
455
456        Ok(SviResult {
457            variational: self.variational.clone(),
458            diagnostics: self.diagnostics.clone(),
459        })
460    }
461
462    /// Estimate ELBO and its gradient using Monte Carlo samples
463    fn estimate_elbo_gradient<F>(
464        &self,
465        batch: ArrayView2<f64>,
466        log_joint: &F,
467        scale_factor: f64,
468    ) -> Result<(f64, Array1<f64>)>
469    where
470        F: Fn(&Array1<f64>, ArrayView2<f64>) -> Result<(f64, Array1<f64>)>,
471    {
472        let dim = self.variational.dim;
473        let n_samples = self.config.n_mc_samples.max(1);
474
475        let mut total_elbo = 0.0;
476        let mut total_grad = Array1::zeros(2 * dim);
477
478        for s in 0..n_samples {
479            // Generate epsilon ~ N(0, I) using simple deterministic approximation
480            // (for production, you'd use a proper RNG here)
481            let epsilon =
482                generate_standard_normal(dim, s as u64 + self.diagnostics.n_iterations as u64);
483
484            // Reparameterization: z = mu + sigma * epsilon
485            let z = self.variational.sample(&epsilon)?;
486
487            // Compute log joint and its gradient
488            let (log_p, grad_z) = log_joint(&z, batch)?;
489
490            // Scale for mini-batch
491            let scaled_log_p = log_p * scale_factor;
492            let scaled_grad_z = &grad_z * scale_factor;
493
494            // Compute log q(z) and entropy gradient
495            let log_q = self.variational.log_prob(&z)?;
496
497            // ELBO = E[log p(z, x) - log q(z)]
498            total_elbo += scaled_log_p - log_q;
499
500            // Gradient of ELBO wrt variational parameters (mu, log_sigma)
501            // d ELBO / d mu = d log_p / d z (through reparameterization)
502            // d ELBO / d log_sigma = d log_p / d z * epsilon * sigma + 1 (entropy gradient)
503            let stds = self.variational.stds();
504            for i in 0..dim {
505                // Gradient wrt mean
506                total_grad[i] += scaled_grad_z[i];
507                // Gradient wrt log_std: chain rule + entropy
508                total_grad[dim + i] += scaled_grad_z[i] * epsilon[i] * stds[i] + 1.0;
509            }
510
511            // Subtract gradient of log q
512            for i in 0..dim {
513                let diff = z[i] - self.variational.means[i];
514                let var = stds[i] * stds[i];
515                // d log q / d mu = (z - mu) / sigma^2
516                total_grad[i] -= diff / var;
517                // d log q / d log_sigma = ((z-mu)^2 / sigma^2 - 1) * sigma ... simplified
518                total_grad[dim + i] -= diff * diff / var - 1.0;
519            }
520        }
521
522        // Average over samples
523        total_elbo /= n_samples as f64;
524        total_grad /= n_samples as f64;
525
526        Ok((total_elbo, total_grad))
527    }
528
529    /// Get the current variational distribution
530    pub fn variational_distribution(&self) -> &MeanFieldGaussian {
531        &self.variational
532    }
533
534    /// Get diagnostics
535    pub fn diagnostics(&self) -> &VariationalDiagnostics {
536        &self.diagnostics
537    }
538
539    /// Reset the optimizer state (useful for warm restarts)
540    pub fn reset_optimizer(&mut self) {
541        if let Some(ref mut adam) = self.adam_state {
542            adam.reset();
543        }
544        self.diagnostics = VariationalDiagnostics::new();
545    }
546}
547
548/// Results from SVI optimization
549#[derive(Debug, Clone)]
550pub struct SviResult {
551    /// Optimized variational distribution
552    pub variational: MeanFieldGaussian,
553    /// Optimization diagnostics
554    pub diagnostics: VariationalDiagnostics,
555}
556
557impl SviResult {
558    /// Get posterior means
559    pub fn posterior_means(&self) -> &Array1<f64> {
560        &self.variational.means
561    }
562
563    /// Get posterior standard deviations
564    pub fn posterior_stds(&self) -> Array1<f64> {
565        self.variational.stds()
566    }
567
568    /// Compute approximate credible intervals
569    pub fn credible_intervals(&self, confidence: f64) -> Result<Array2<f64>> {
570        check_probability(confidence, "confidence")?;
571        let alpha = (1.0 - confidence) / 2.0;
572        let z_critical = super::normal_ppf(1.0 - alpha)?;
573
574        let dim = self.variational.dim;
575        let mut intervals = Array2::zeros((dim, 2));
576        let stds = self.variational.stds();
577
578        for i in 0..dim {
579            intervals[[i, 0]] = self.variational.means[i] - z_critical * stds[i];
580            intervals[[i, 1]] = self.variational.means[i] + z_critical * stds[i];
581        }
582
583        Ok(intervals)
584    }
585}
586
587// ============================================================================
588// SVI for Bayesian Linear Regression
589// ============================================================================
590
591/// SVI-specialized Bayesian linear regression
592///
593/// Uses stochastic variational inference to fit a Bayesian linear regression
594/// model, making it suitable for large datasets that don't fit in memory.
595///
596/// Model: y = X * beta + epsilon, epsilon ~ N(0, sigma^2)
597/// Prior: beta ~ N(0, prior_var * I), sigma^2 ~ InvGamma(a, b)
598#[derive(Debug, Clone)]
599pub struct SviBayesianRegression {
600    /// Variational mean for coefficients
601    pub mean_beta: Array1<f64>,
602    /// Variational log std for coefficients
603    pub log_std_beta: Array1<f64>,
604    /// Variational parameters for noise precision: shape and rate of Gamma
605    pub shape_tau: f64,
606    pub rate_tau: f64,
607    /// Prior variance for coefficients
608    pub prior_var: f64,
609    /// Prior shape for noise precision
610    pub prior_shape: f64,
611    /// Prior rate for noise precision
612    pub prior_rate: f64,
613    /// Number of features
614    pub n_features: usize,
615    /// SVI configuration
616    pub config: SviConfig,
617}
618
619impl SviBayesianRegression {
620    /// Create a new SVI Bayesian regression model
621    pub fn new(n_features: usize, config: SviConfig) -> Result<Self> {
622        check_positive(n_features, "n_features")?;
623
624        Ok(Self {
625            mean_beta: Array1::zeros(n_features),
626            log_std_beta: Array1::zeros(n_features),
627            shape_tau: 1.0,
628            rate_tau: 1.0,
629            prior_var: 100.0,
630            prior_shape: 1e-3,
631            prior_rate: 1e-3,
632            n_features,
633            config,
634        })
635    }
636
637    /// Set prior parameters
638    pub fn with_priors(
639        mut self,
640        prior_var: f64,
641        prior_shape: f64,
642        prior_rate: f64,
643    ) -> Result<Self> {
644        check_positive(prior_var, "prior_var")?;
645        check_positive(prior_shape, "prior_shape")?;
646        check_positive(prior_rate, "prior_rate")?;
647        self.prior_var = prior_var;
648        self.prior_shape = prior_shape;
649        self.prior_rate = prior_rate;
650        Ok(self)
651    }
652
653    /// Fit using SVI with mini-batches
654    pub fn fit(&mut self, x: ArrayView2<f64>, y: ArrayView1<f64>) -> Result<SviRegressionResult> {
655        checkarray_finite(&x, "x")?;
656        checkarray_finite(&y, "y")?;
657
658        let (n_samples, n_features) = x.dim();
659        if y.len() != n_samples {
660            return Err(StatsError::DimensionMismatch(format!(
661                "y length ({}) must match x rows ({})",
662                y.len(),
663                n_samples
664            )));
665        }
666        if n_features != self.n_features {
667            return Err(StatsError::DimensionMismatch(format!(
668                "x features ({}) must match model features ({})",
669                n_features, self.n_features
670            )));
671        }
672
673        let batch_size = self.config.batch_size.min(n_samples);
674        let scale_factor = n_samples as f64 / batch_size as f64;
675        let offset = (self.config.seed % n_samples as u64) as usize;
676
677        // Initialize Adam state for all parameters
678        // Parameters: [mean_beta (d), log_std_beta (d), log_shape_tau, log_rate_tau]
679        let n_params = 2 * self.n_features + 2;
680        let mut adam_state = if let LearningRateSchedule::Adam {
681            lr,
682            beta1,
683            beta2,
684            epsilon,
685        } = &self.config.lr_schedule
686        {
687            Some(AdamState::new(n_params, *lr, *beta1, *beta2, *epsilon)?)
688        } else {
689            None
690        };
691
692        let mut diagnostics = VariationalDiagnostics::new();
693
694        for iter in 0..self.config.max_iter {
695            // Select mini-batch
696            let batch_start = (offset + iter * batch_size) % n_samples;
697            let batch_end = (batch_start + batch_size).min(n_samples);
698
699            let x_batch = x.slice(scirs2_core::ndarray::s![batch_start..batch_end, ..]);
700            let y_batch = y.slice(scirs2_core::ndarray::s![batch_start..batch_end]);
701
702            // Compute stochastic ELBO gradient
703            let (elbo, grad) =
704                self.compute_stochastic_elbo_grad(x_batch, y_batch, scale_factor, iter as u64)?;
705
706            diagnostics.record_elbo(elbo);
707
708            let grad_norm = grad.dot(&grad).sqrt();
709            diagnostics.record_gradient_norm(grad_norm);
710
711            // Clip gradient
712            let clipped_grad = if self.config.grad_clip > 0.0 && grad_norm > self.config.grad_clip {
713                &grad * (self.config.grad_clip / grad_norm)
714            } else {
715                grad
716            };
717
718            // Get current parameters
719            let current_params = self.get_params();
720
721            // Apply update
722            let new_params = if let Some(ref mut adam) = adam_state {
723                let update = adam.compute_update(&clipped_grad)?;
724                &current_params + &update
725            } else {
726                let lr = self.config.lr_schedule.get_lr(iter);
727                &current_params + &(&clipped_grad * lr)
728            };
729
730            let param_change = (&new_params - &current_params).mapv(|x| x * x).sum().sqrt();
731            diagnostics.record_param_change(param_change);
732
733            self.set_params(&new_params)?;
734
735            // Check convergence
736            if iter > 20 && diagnostics.check_elbo_convergence(self.config.tol) {
737                diagnostics.converged = true;
738                break;
739            }
740        }
741
742        Ok(SviRegressionResult {
743            mean_beta: self.mean_beta.clone(),
744            std_beta: self.log_std_beta.mapv(f64::exp),
745            shape_tau: self.shape_tau,
746            rate_tau: self.rate_tau,
747            diagnostics,
748        })
749    }
750
751    /// Compute stochastic ELBO and gradient for a mini-batch
752    fn compute_stochastic_elbo_grad(
753        &self,
754        x_batch: ArrayView2<f64>,
755        y_batch: ArrayView1<f64>,
756        scale_factor: f64,
757        seed: u64,
758    ) -> Result<(f64, Array1<f64>)> {
759        let n_batch = x_batch.nrows();
760        let d = self.n_features;
761        let n_params = 2 * d + 2;
762
763        let std_beta = self.log_std_beta.mapv(f64::exp);
764        let expected_tau = self.shape_tau / self.rate_tau;
765        let expected_log_tau = digamma(self.shape_tau) - self.rate_tau.ln();
766
767        // Sample beta using reparameterization trick
768        let epsilon = generate_standard_normal(d, seed);
769        let beta_sample = &self.mean_beta + &(&std_beta * &epsilon);
770
771        // Compute residuals
772        let y_pred = x_batch.dot(&beta_sample);
773        let residuals = &y_batch.to_owned() - &y_pred;
774        let sse = residuals.dot(&residuals);
775
776        // Scaled likelihood term
777        let likelihood = scale_factor
778            * (0.5 * n_batch as f64 * expected_log_tau
779                - 0.5 * n_batch as f64 * (2.0 * PI).ln()
780                - 0.5 * expected_tau * sse);
781
782        // Prior term for beta
783        let beta_sq_sum = beta_sample.dot(&beta_sample);
784        let prior_beta =
785            -0.5 * d as f64 * (2.0 * PI * self.prior_var).ln() - 0.5 / self.prior_var * beta_sq_sum;
786
787        // Prior term for tau
788        let prior_tau = self.prior_shape * self.prior_rate.ln() - lgamma(self.prior_shape)
789            + (self.prior_shape - 1.0) * expected_log_tau
790            - self.prior_rate * expected_tau;
791
792        // Entropy of q(beta)
793        let entropy_beta: f64 = (0..d)
794            .map(|i| 0.5 * (1.0 + (2.0 * PI).ln()) + self.log_std_beta[i])
795            .sum();
796
797        // Entropy of q(tau) (Gamma distribution)
798        let entropy_tau = self.shape_tau - self.rate_tau.ln()
799            + lgamma(self.shape_tau)
800            + (1.0 - self.shape_tau) * digamma(self.shape_tau);
801
802        let elbo = likelihood + prior_beta + prior_tau + entropy_beta + entropy_tau;
803
804        // Compute gradients
805        let mut grad = Array1::zeros(n_params);
806
807        // Gradient wrt mean_beta
808        let grad_beta_from_likelihood = x_batch.t().dot(&residuals) * expected_tau * scale_factor;
809        let grad_beta_from_prior = &beta_sample * (-1.0 / self.prior_var);
810
811        for i in 0..d {
812            grad[i] = grad_beta_from_likelihood[i] + grad_beta_from_prior[i];
813        }
814
815        // Gradient wrt log_std_beta (through reparameterization)
816        for i in 0..d {
817            let dl_dbeta = grad_beta_from_likelihood[i] + grad_beta_from_prior[i];
818            // Chain rule: d/d(log_sigma) = d/dbeta * dbeta/d(log_sigma)
819            // dbeta/d(log_sigma) = epsilon * sigma (since beta = mu + sigma*epsilon)
820            grad[d + i] = dl_dbeta * epsilon[i] * std_beta[i] + 1.0; // +1 from entropy
821        }
822
823        // Gradient wrt shape_tau and rate_tau (use simpler gradient ascent on these)
824        // d ELBO / d shape_tau
825        let d_likelihood_shape =
826            scale_factor * 0.5 * n_batch as f64 * super::trigamma(self.shape_tau);
827        let d_prior_shape = (self.prior_shape - 1.0) * super::trigamma(self.shape_tau)
828            - self.prior_rate / self.rate_tau;
829        let d_entropy_shape = 1.0 - (1.0 - self.shape_tau) * super::trigamma(self.shape_tau)
830            + digamma(self.shape_tau) * (-1.0)
831            + super::trigamma(self.shape_tau) * (1.0 - self.shape_tau);
832        // Simplified: just compute numerically stable gradient
833        grad[2 * d] = d_likelihood_shape + d_prior_shape + d_entropy_shape * 0.01;
834
835        // d ELBO / d rate_tau
836        let d_likelihood_rate =
837            -scale_factor * 0.5 * sse * self.shape_tau / (self.rate_tau * self.rate_tau);
838        let d_prior_rate = self.prior_rate * self.shape_tau / (self.rate_tau * self.rate_tau);
839        grad[2 * d + 1] = d_likelihood_rate - d_prior_rate + 1.0 / self.rate_tau;
840
841        Ok((elbo, grad))
842    }
843
844    fn get_params(&self) -> Array1<f64> {
845        let d = self.n_features;
846        let mut params = Array1::zeros(2 * d + 2);
847        for i in 0..d {
848            params[i] = self.mean_beta[i];
849            params[d + i] = self.log_std_beta[i];
850        }
851        params[2 * d] = self.shape_tau;
852        params[2 * d + 1] = self.rate_tau;
853        params
854    }
855
856    fn set_params(&mut self, params: &Array1<f64>) -> Result<()> {
857        let d = self.n_features;
858        if params.len() != 2 * d + 2 {
859            return Err(StatsError::DimensionMismatch(format!(
860                "params length ({}) must be {}",
861                params.len(),
862                2 * d + 2
863            )));
864        }
865        for i in 0..d {
866            self.mean_beta[i] = params[i];
867            self.log_std_beta[i] = params[d + i];
868        }
869        // Ensure shape and rate stay positive
870        self.shape_tau = params[2 * d].max(1e-6);
871        self.rate_tau = params[2 * d + 1].max(1e-6);
872        Ok(())
873    }
874}
875
/// Results from SVI Bayesian regression.
///
/// Holds the fitted variational posterior together with the optimizer diagnostics:
/// - per-coefficient Gaussian summaries (`mean_beta`, `std_beta`);
/// - the shape/rate pair describing the noise precision `tau`.
#[derive(Debug, Clone)]
pub struct SviRegressionResult {
    /// Posterior (variational) mean of each regression coefficient.
    pub mean_beta: Array1<f64>,
    /// Posterior (variational) standard deviation of each coefficient.
    /// `credible_intervals` indexes this in lockstep with `mean_beta`, so it is
    /// expected to have the same length.
    pub std_beta: Array1<f64>,
    /// Posterior shape parameter for the noise precision `tau`.
    /// NOTE(review): `expected_noise_variance` interprets this as the shape of a
    /// Gamma(shape, rate) distribution on `tau`.
    pub shape_tau: f64,
    /// Posterior rate parameter for the noise precision `tau`.
    pub rate_tau: f64,
    /// Optimization diagnostics
    pub diagnostics: VariationalDiagnostics,
}
890
891impl SviRegressionResult {
892    /// Get expected noise variance: E[1/tau] = rate / (shape - 1) for shape > 1
893    pub fn expected_noise_variance(&self) -> f64 {
894        if self.shape_tau > 1.0 {
895            self.rate_tau / (self.shape_tau - 1.0)
896        } else {
897            self.rate_tau / self.shape_tau
898        }
899    }
900
901    /// Compute credible intervals for coefficients
902    pub fn credible_intervals(&self, confidence: f64) -> Result<Array2<f64>> {
903        check_probability(confidence, "confidence")?;
904        let alpha = (1.0 - confidence) / 2.0;
905        let z_critical = super::normal_ppf(1.0 - alpha)?;
906
907        let d = self.mean_beta.len();
908        let mut intervals = Array2::zeros((d, 2));
909        for i in 0..d {
910            intervals[[i, 0]] = self.mean_beta[i] - z_critical * self.std_beta[i];
911            intervals[[i, 1]] = self.mean_beta[i] + z_critical * self.std_beta[i];
912        }
913        Ok(intervals)
914    }
915}
916
917// ============================================================================
918// Helper functions
919// ============================================================================
920
921/// Generate approximate standard normal samples using Box-Muller-like
922/// deterministic scheme (based on a seed for reproducibility)
923///
924/// For production use, one should use a proper PRNG. This provides
925/// a deterministic but reasonable approximation for testing and
926/// demonstration purposes.
927fn generate_standard_normal(dim: usize, seed: u64) -> Array1<f64> {
928    let mut result = Array1::zeros(dim);
929    let golden_ratio = 1.618033988749895;
930
931    for i in 0..dim {
932        // Use a quasi-random sequence based on golden ratio
933        let u1 = ((seed as f64 * golden_ratio + i as f64 * 0.7548776662466927) % 1.0).abs();
934        let u2 = ((seed as f64 * 0.5698402909980532 + i as f64 * golden_ratio) % 1.0).abs();
935
936        // Clamp to avoid log(0)
937        let u1_safe = u1.max(1e-10).min(1.0 - 1e-10);
938        let u2_safe = u2.max(1e-10).min(1.0 - 1e-10);
939
940        // Box-Muller transform
941        let r = (-2.0 * u1_safe.ln()).sqrt();
942        let theta = 2.0 * PI * u2_safe;
943        result[i] = r * theta.cos();
944    }
945
946    result
947}
948
949// ============================================================================
950// Tests
951// ============================================================================
952
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    // A constant schedule must return the same rate regardless of iteration.
    #[test]
    fn test_learning_rate_constant() {
        let lr = LearningRateSchedule::Constant { lr: 0.01 };
        assert!((lr.get_lr(0) - 0.01).abs() < 1e-10);
        assert!((lr.get_lr(100) - 0.01).abs() < 1e-10);
    }

    // Robbins-Monro starts at lr0 and decays toward (but never reaches) zero.
    #[test]
    fn test_learning_rate_robbins_monro() {
        let lr = LearningRateSchedule::RobbinsMonro {
            lr0: 0.1,
            decay: 0.01,
        };
        assert!((lr.get_lr(0) - 0.1).abs() < 1e-10);
        assert!(lr.get_lr(100) < lr.get_lr(0));
        assert!(lr.get_lr(100) > 0.0);
    }

    // Exponential decay starts at lr0 and shrinks with gamma < 1.
    #[test]
    fn test_learning_rate_exponential() {
        let lr = LearningRateSchedule::ExponentialDecay {
            lr0: 0.1,
            gamma: 0.99,
        };
        assert!((lr.get_lr(0) - 0.1).abs() < 1e-10);
        assert!(lr.get_lr(100) < lr.get_lr(0));
    }

    // Smoke test: one Adam step yields an update of the right length with finite values.
    #[test]
    fn test_adam_state() {
        let mut adam = AdamState::new(3, 0.01, 0.9, 0.999, 1e-8).expect("should create");
        let grad = Array1::from_vec(vec![1.0, -0.5, 0.3]);
        let update = adam.compute_update(&grad).expect("should compute update");
        assert_eq!(update.len(), 3);
        // First step should be approximately lr * grad / (sqrt(grad^2) + eps) = lr * sign(grad)
        // NOTE(review): only finiteness is asserted below; the magnitude/sign expectation
        // described above is not checked.
        for i in 0..3 {
            assert!(update[i].is_finite());
        }
    }

    // Converting mean-field params to natural parameters and back must be lossless
    // (to 1e-6) for both the means and the log standard deviations.
    #[test]
    fn test_natural_gradient_roundtrip() {
        let mf = MeanFieldGaussian::from_params(
            Array1::from_vec(vec![1.0, 2.0]),
            Array1::from_vec(vec![0.5, -0.3]),
        )
        .expect("should create");

        let nat = NaturalGradientParams::from_mean_field(&mf);
        let recovered = nat.to_mean_field().expect("should convert back");

        for i in 0..2 {
            assert!(
                (recovered.means[i] - mf.means[i]).abs() < 1e-6,
                "means differ at {}: {} vs {}",
                i,
                recovered.means[i],
                mf.means[i]
            );
            assert!(
                (recovered.log_stds[i] - mf.log_stds[i]).abs() < 1e-6,
                "log_stds differ at {}: {} vs {}",
                i,
                recovered.log_stds[i],
                mf.log_stds[i]
            );
        }
    }

    // Constructing the SVI driver sizes its variational family to the requested dimension.
    #[test]
    fn test_svi_creation() {
        let config = SviConfig {
            max_iter: 100,
            batch_size: 10,
            ..SviConfig::default()
        };
        let svi = StochasticVI::new(5, config).expect("should create SVI");
        assert_eq!(svi.variational.dim, 5);
    }

    // End-to-end smoke test of SVI regression on y = x plus a deterministic,
    // bounded perturbation in [-0.1, 0.1].
    #[test]
    fn test_svi_bayesian_regression() {
        // Simple test: y = x + noise
        let n = 100;
        let x_data: Vec<f64> = (0..n).map(|i| i as f64 / n as f64).collect();
        let y_data: Vec<f64> = x_data
            .iter()
            .enumerate()
            .map(|(i, &xi)| xi + 0.1 * ((i * 7 % 13) as f64 - 6.0) / 6.0)
            .collect();

        let x = Array2::from_shape_fn((n, 1), |(i, _)| x_data[i]);
        let y = Array1::from_vec(y_data);

        let config = SviConfig {
            max_iter: 200,
            batch_size: 20,
            lr_schedule: LearningRateSchedule::Adam {
                lr: 0.01,
                beta1: 0.9,
                beta2: 0.999,
                epsilon: 1e-8,
            },
            ..SviConfig::default()
        };

        let mut model = SviBayesianRegression::new(1, config).expect("should fit");
        let result = model.fit(x.view(), y.view()).expect("should fit");

        // Check that we get finite results
        // NOTE(review): recovery of the true slope (1.0) is not asserted.
        assert!(result.mean_beta[0].is_finite());
        assert!(result.std_beta[0].is_finite());
        assert!(result.diagnostics.n_iterations > 0);
    }

    // The quasi-random normal generator returns the requested count of finite values.
    #[test]
    fn test_generate_standard_normal() {
        let samples = generate_standard_normal(100, 42);
        assert_eq!(samples.len(), 100);
        // All should be finite
        for &s in samples.iter() {
            assert!(s.is_finite(), "sample should be finite, got {}", s);
        }
        // Mean should be roughly zero (within reasonable bounds for quasi-random)
        // NOTE(review): the 2.0 bound is very loose; this is a smoke check only.
        let mean = samples.sum() / 100.0;
        assert!(
            mean.abs() < 2.0,
            "mean should be roughly zero, got {}",
            mean
        );
    }

    // Pins the documented defaults of SviConfig.
    #[test]
    fn test_svi_config_default() {
        let config = SviConfig::default();
        assert_eq!(config.max_iter, 1000);
        assert_eq!(config.batch_size, 32);
        assert!(config.grad_clip > 0.0);
    }
}