Skip to main content

scirs2_stats/bayesian/variational/
mod.rs

1//! Variational inference for Bayesian models
2//!
3//! This module implements variational inference methods as alternatives to MCMC
4//! for approximate Bayesian inference. It provides:
5//!
6//! - **Coordinate Ascent VI (CAVI)**: Classical variational inference for Bayesian
7//!   linear regression and Automatic Relevance Determination (ARD)
8//! - **Stochastic Variational Inference (SVI)**: Scalable VI with mini-batch
9//!   ELBO estimation, natural gradient updates, and learning rate scheduling
10//! - **ADVI**: Automatic Differentiation Variational Inference with mean-field
11//!   and full-rank Gaussian approximations
12//! - **Variational Families**: Mean-field Gaussian, full-rank Gaussian, and
13//!   normalizing flow placeholders
14//! - **Diagnostics**: ELBO trace, gradient norm monitoring, convergence checks
15
16mod advi;
17mod families;
18mod svi;
19
20pub use advi::*;
21pub use families::*;
22pub use svi::*;
23
24use crate::error::{StatsError, StatsResult as Result};
25use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
26use scirs2_core::validation::*;
27use statrs::statistics::Statistics;
28use std::f64::consts::PI;
29
30// ============================================================================
31// Variational Diagnostics
32// ============================================================================
33
/// Diagnostics for monitoring variational inference convergence.
///
/// Collects per-iteration traces (ELBO, gradient norms, parameter-change
/// norms) and offers simple convergence checks on top of them. The
/// `converged` flag is never modified by the methods of this type; it is left
/// for the caller to set.
#[derive(Debug, Clone)]
pub struct VariationalDiagnostics {
    /// ELBO values at each iteration
    pub elbo_trace: Vec<f64>,
    /// Gradient norms at each iteration (if tracked)
    pub gradient_norms: Vec<f64>,
    /// Parameter change norms at each iteration (if tracked)
    pub param_change_norms: Vec<f64>,
    /// Whether the algorithm converged
    pub converged: bool,
    /// Number of iterations performed (kept equal to `elbo_trace.len()` by
    /// [`VariationalDiagnostics::record_elbo`])
    pub n_iterations: usize,
    /// Final ELBO value (`f64::NEG_INFINITY` until the first ELBO is recorded)
    pub final_elbo: f64,
}

impl VariationalDiagnostics {
    /// Create an empty diagnostics tracker.
    pub fn new() -> Self {
        Self {
            elbo_trace: Vec::new(),
            gradient_norms: Vec::new(),
            param_change_norms: Vec::new(),
            converged: false,
            n_iterations: 0,
            final_elbo: f64::NEG_INFINITY,
        }
    }

    /// Record an ELBO value and refresh `final_elbo` / `n_iterations`.
    pub fn record_elbo(&mut self, elbo: f64) {
        self.elbo_trace.push(elbo);
        self.final_elbo = elbo;
        self.n_iterations = self.elbo_trace.len();
    }

    /// Record a gradient norm.
    pub fn record_gradient_norm(&mut self, norm: f64) {
        self.gradient_norms.push(norm);
    }

    /// Record a parameter change norm.
    pub fn record_param_change(&mut self, norm: f64) {
        self.param_change_norms.push(norm);
    }

    /// Check convergence from the absolute change between the last two ELBO
    /// values.
    ///
    /// Returns `false` when fewer than two ELBO values have been recorded.
    pub fn check_elbo_convergence(&self, tol: f64) -> bool {
        match self.elbo_trace.as_slice() {
            [.., previous, latest] => (latest - previous).abs() < tol,
            _ => false,
        }
    }

    /// Check whether the most recent gradient norm is below `tol`.
    ///
    /// Returns `false` when no gradient norm has been recorded.
    pub fn check_gradient_convergence(&self, tol: f64) -> bool {
        matches!(self.gradient_norms.last(), Some(&norm) if norm < tol)
    }

    /// Check whether the most recent parameter change norm is below `tol`.
    ///
    /// Returns `false` when no parameter change has been recorded.
    pub fn check_param_convergence(&self, tol: f64) -> bool {
        matches!(self.param_change_norms.last(), Some(&change) if change < tol)
    }

    /// Compute the relative ELBO change over the last `window` iterations.
    ///
    /// The result is `|elbo[n-1] - elbo[n-1-window]| / |elbo[n-1-window]|`.
    ///
    /// Returns `None` when fewer than `window + 1` ELBO values are available,
    /// and `Some(f64::INFINITY)` when the earlier value is (numerically) zero,
    /// since the relative change is undefined there.
    pub fn relative_elbo_change(&self, window: usize) -> Option<f64> {
        let n = self.elbo_trace.len();
        // `n <= window` is the overflow-free form of `n < window + 1`; the
        // latter overflows (and panics in debug builds) for `usize::MAX`.
        if n <= window {
            return None;
        }
        let recent = self.elbo_trace[n - 1];
        let earlier = self.elbo_trace[n - 1 - window];
        if earlier.abs() < 1e-15 {
            return Some(f64::INFINITY);
        }
        Some((recent - earlier).abs() / earlier.abs())
    }

    /// Get summary statistics for the ELBO trace.
    ///
    /// For an empty trace every numeric field is `NaN` and `monotonic` is
    /// `true`. A decrease smaller than `1e-10` between consecutive iterations
    /// is tolerated when deciding monotonicity.
    pub fn elbo_summary(&self) -> ElboSummary {
        let final_value = match self.elbo_trace.last() {
            Some(&value) => value,
            None => {
                return ElboSummary {
                    min: f64::NAN,
                    max: f64::NAN,
                    final_value: f64::NAN,
                    mean_change: f64::NAN,
                    monotonic: true,
                };
            }
        };

        let min = self
            .elbo_trace
            .iter()
            .copied()
            .fold(f64::INFINITY, f64::min);
        let max = self
            .elbo_trace
            .iter()
            .copied()
            .fold(f64::NEG_INFINITY, f64::max);

        let mut monotonic = true;
        let mut total_change = 0.0;
        for pair in self.elbo_trace.windows(2) {
            let change = pair[1] - pair[0];
            total_change += change.abs();
            if change < -1e-10 {
                monotonic = false;
            }
        }

        let n_steps = self.elbo_trace.len() - 1;
        let mean_change = if n_steps > 0 {
            total_change / n_steps as f64
        } else {
            0.0
        };

        ElboSummary {
            min,
            max,
            final_value,
            mean_change,
            monotonic,
        }
    }
}

impl Default for VariationalDiagnostics {
    fn default() -> Self {
        Self::new()
    }
}

/// Summary statistics for the ELBO trace
#[derive(Debug, Clone)]
pub struct ElboSummary {
    /// Minimum ELBO observed
    pub min: f64,
    /// Maximum ELBO observed
    pub max: f64,
    /// Final ELBO value
    pub final_value: f64,
    /// Mean absolute change between consecutive iterations
    pub mean_change: f64,
    /// Whether the ELBO was monotonically increasing (up to a `1e-10` slack)
    pub monotonic: bool,
}
192
193// ============================================================================
194// Variational Families
195// ============================================================================
196
197/// Mean-field Gaussian variational family
198///
199/// Approximates the posterior with a factorized Gaussian:
200/// q(z) = prod_i N(z_i; mu_i, sigma_i^2)
201///
202/// This is the simplest variational family but cannot capture correlations.
203#[derive(Debug, Clone)]
204pub struct MeanFieldGaussian {
205    /// Variational means
206    pub means: Array1<f64>,
207    /// Variational log standard deviations (unconstrained parameterization)
208    pub log_stds: Array1<f64>,
209    /// Dimensionality
210    pub dim: usize,
211}
212
213impl MeanFieldGaussian {
214    /// Create a new mean-field Gaussian with given dimension
215    pub fn new(dim: usize) -> Result<Self> {
216        check_positive(dim, "dim")?;
217        Ok(Self {
218            means: Array1::zeros(dim),
219            log_stds: Array1::zeros(dim), // std = 1.0 initially
220            dim,
221        })
222    }
223
224    /// Create from specific parameters
225    pub fn from_params(means: Array1<f64>, log_stds: Array1<f64>) -> Result<Self> {
226        if means.len() != log_stds.len() {
227            return Err(StatsError::DimensionMismatch(format!(
228                "means length ({}) must match log_stds length ({})",
229                means.len(),
230                log_stds.len()
231            )));
232        }
233        checkarray_finite(&means, "means")?;
234        checkarray_finite(&log_stds, "log_stds")?;
235        let dim = means.len();
236        Ok(Self {
237            means,
238            log_stds,
239            dim,
240        })
241    }
242
243    /// Get the standard deviations
244    pub fn stds(&self) -> Array1<f64> {
245        self.log_stds.mapv(f64::exp)
246    }
247
248    /// Get the variances
249    pub fn variances(&self) -> Array1<f64> {
250        self.log_stds.mapv(|ls| (2.0 * ls).exp())
251    }
252
253    /// Sample from the variational distribution using reparameterization trick
254    ///
255    /// z = mu + sigma * epsilon, where epsilon ~ N(0, I)
256    pub fn sample(&self, epsilon: &Array1<f64>) -> Result<Array1<f64>> {
257        if epsilon.len() != self.dim {
258            return Err(StatsError::DimensionMismatch(format!(
259                "epsilon length ({}) must match dimension ({})",
260                epsilon.len(),
261                self.dim
262            )));
263        }
264        let stds = self.stds();
265        Ok(&self.means + &(&stds * epsilon))
266    }
267
268    /// Compute the entropy of the mean-field Gaussian
269    /// H\[q\] = sum_i 0.5 * (1 + log(2*pi) + 2*log_std_i)
270    pub fn entropy(&self) -> f64 {
271        let base = 0.5 * (1.0 + (2.0 * PI).ln());
272        self.log_stds.iter().map(|&ls| base + ls).sum::<f64>()
273    }
274
275    /// Compute log q(z) for a given z
276    pub fn log_prob(&self, z: &Array1<f64>) -> Result<f64> {
277        if z.len() != self.dim {
278            return Err(StatsError::DimensionMismatch(format!(
279                "z length ({}) must match dimension ({})",
280                z.len(),
281                self.dim
282            )));
283        }
284        let stds = self.stds();
285        let mut log_prob = 0.0;
286        for i in 0..self.dim {
287            let diff = z[i] - self.means[i];
288            log_prob += -0.5 * (2.0 * PI).ln() - self.log_stds[i] - 0.5 * (diff / stds[i]).powi(2);
289        }
290        Ok(log_prob)
291    }
292
293    /// Get total number of variational parameters
294    pub fn n_params(&self) -> usize {
295        2 * self.dim
296    }
297
298    /// Get all variational parameters as a flat vector
299    pub fn get_params(&self) -> Array1<f64> {
300        let mut params = Array1::zeros(2 * self.dim);
301        for i in 0..self.dim {
302            params[i] = self.means[i];
303            params[self.dim + i] = self.log_stds[i];
304        }
305        params
306    }
307
308    /// Set variational parameters from a flat vector
309    pub fn set_params(&mut self, params: &Array1<f64>) -> Result<()> {
310        if params.len() != 2 * self.dim {
311            return Err(StatsError::DimensionMismatch(format!(
312                "params length ({}) must be 2 * dim ({})",
313                params.len(),
314                2 * self.dim
315            )));
316        }
317        for i in 0..self.dim {
318            self.means[i] = params[i];
319            self.log_stds[i] = params[self.dim + i];
320        }
321        Ok(())
322    }
323}
324
325/// Full-rank Gaussian variational family
326///
327/// Approximates the posterior with a Gaussian with full covariance:
328/// q(z) = N(z; mu, Sigma) where Sigma = L L^T (Cholesky parameterization)
329///
330/// This can capture correlations but has O(d^2) parameters.
331#[derive(Debug, Clone)]
332pub struct FullRankGaussian {
333    /// Variational mean
334    pub mean: Array1<f64>,
335    /// Lower triangular Cholesky factor of the covariance
336    /// Stored as a flattened lower-triangular matrix
337    pub chol_factor: Array2<f64>,
338    /// Dimensionality
339    pub dim: usize,
340}
341
342impl FullRankGaussian {
343    /// Create a new full-rank Gaussian with given dimension
344    pub fn new(dim: usize) -> Result<Self> {
345        check_positive(dim, "dim")?;
346        Ok(Self {
347            mean: Array1::zeros(dim),
348            chol_factor: Array2::eye(dim), // Identity = unit covariance
349            dim,
350        })
351    }
352
353    /// Create from specific parameters
354    pub fn from_params(mean: Array1<f64>, chol_factor: Array2<f64>) -> Result<Self> {
355        let dim = mean.len();
356        if chol_factor.nrows() != dim || chol_factor.ncols() != dim {
357            return Err(StatsError::DimensionMismatch(format!(
358                "chol_factor shape ({},{}) must be ({},{})",
359                chol_factor.nrows(),
360                chol_factor.ncols(),
361                dim,
362                dim
363            )));
364        }
365        checkarray_finite(&mean, "mean")?;
366        checkarray_finite(&chol_factor, "chol_factor")?;
367        Ok(Self {
368            mean,
369            chol_factor,
370            dim,
371        })
372    }
373
374    /// Get the covariance matrix: Sigma = L * L^T
375    pub fn covariance(&self) -> Array2<f64> {
376        self.chol_factor.dot(&self.chol_factor.t())
377    }
378
379    /// Get the precision matrix (inverse covariance)
380    pub fn precision(&self) -> Result<Array2<f64>> {
381        let cov = self.covariance();
382        scirs2_linalg::inv(&cov.view(), None).map_err(|e| {
383            StatsError::ComputationError(format!("Failed to invert covariance: {}", e))
384        })
385    }
386
387    /// Sample from the variational distribution using reparameterization trick
388    ///
389    /// z = mu + L * epsilon, where epsilon ~ N(0, I)
390    pub fn sample(&self, epsilon: &Array1<f64>) -> Result<Array1<f64>> {
391        if epsilon.len() != self.dim {
392            return Err(StatsError::DimensionMismatch(format!(
393                "epsilon length ({}) must match dimension ({})",
394                epsilon.len(),
395                self.dim
396            )));
397        }
398        Ok(&self.mean + &self.chol_factor.dot(epsilon))
399    }
400
401    /// Compute the entropy of the full-rank Gaussian
402    /// H\[q\] = 0.5 * d * (1 + log(2*pi)) + sum_i log(L_ii)
403    pub fn entropy(&self) -> f64 {
404        let base = 0.5 * self.dim as f64 * (1.0 + (2.0 * PI).ln());
405        let log_det: f64 = (0..self.dim)
406            .map(|i| self.chol_factor[[i, i]].abs().ln())
407            .sum();
408        base + log_det
409    }
410
411    /// Compute log q(z) for a given z
412    pub fn log_prob(&self, z: &Array1<f64>) -> Result<f64> {
413        if z.len() != self.dim {
414            return Err(StatsError::DimensionMismatch(format!(
415                "z length ({}) must match dimension ({})",
416                z.len(),
417                self.dim
418            )));
419        }
420        let precision = self.precision()?;
421        let diff = z - &self.mean;
422        let mahal = diff.dot(&precision.dot(&diff));
423        let log_det: f64 = (0..self.dim)
424            .map(|i| self.chol_factor[[i, i]].abs().ln())
425            .sum();
426        Ok(-0.5 * self.dim as f64 * (2.0 * PI).ln() - log_det - 0.5 * mahal)
427    }
428
429    /// Get total number of variational parameters
430    pub fn n_params(&self) -> usize {
431        self.dim + self.dim * (self.dim + 1) / 2
432    }
433
434    /// Get all variational parameters as a flat vector
435    /// Layout: [mean, lower-triangular elements of L]
436    pub fn get_params(&self) -> Array1<f64> {
437        let n_tril = self.dim * (self.dim + 1) / 2;
438        let mut params = Array1::zeros(self.dim + n_tril);
439        // Mean
440        for i in 0..self.dim {
441            params[i] = self.mean[i];
442        }
443        // Lower triangular
444        let mut idx = self.dim;
445        for i in 0..self.dim {
446            for j in 0..=i {
447                params[idx] = self.chol_factor[[i, j]];
448                idx += 1;
449            }
450        }
451        params
452    }
453
454    /// Set variational parameters from a flat vector
455    pub fn set_params(&mut self, params: &Array1<f64>) -> Result<()> {
456        let n_tril = self.dim * (self.dim + 1) / 2;
457        let expected = self.dim + n_tril;
458        if params.len() != expected {
459            return Err(StatsError::DimensionMismatch(format!(
460                "params length ({}) must be {}",
461                params.len(),
462                expected
463            )));
464        }
465        // Mean
466        for i in 0..self.dim {
467            self.mean[i] = params[i];
468        }
469        // Lower triangular
470        let mut idx = self.dim;
471        self.chol_factor = Array2::zeros((self.dim, self.dim));
472        for i in 0..self.dim {
473            for j in 0..=i {
474                self.chol_factor[[i, j]] = params[idx];
475                idx += 1;
476            }
477        }
478        Ok(())
479    }
480}
481
482/// Normalizing flow variational family (placeholder/scaffold)
483///
484/// This provides a framework for flow-based variational inference where
485/// the posterior is represented as a composition of invertible transformations
486/// applied to a base distribution (typically a standard Gaussian).
487///
488/// q(z) = q_0(f^{-1}(z)) * |det(df^{-1}/dz)|
489///
490/// Currently supports:
491/// - Planar flows: f(z) = z + u * h(w^T z + b)
492/// - Radial flows: f(z) = z + beta * h(alpha, r)(z - z0)
493#[derive(Debug, Clone)]
494pub struct NormalizingFlowVI {
495    /// Base distribution (mean-field Gaussian)
496    pub base: MeanFieldGaussian,
497    /// Flow layers
498    pub flows: Vec<FlowLayer>,
499    /// Dimensionality
500    pub dim: usize,
501}
502
503/// A single flow layer (invertible transformation)
504#[derive(Debug, Clone)]
505pub enum FlowLayer {
506    /// Planar flow: f(z) = z + u * tanh(w^T z + b)
507    Planar {
508        /// Direction of perturbation
509        u: Array1<f64>,
510        /// Weight vector
511        w: Array1<f64>,
512        /// Bias
513        b: f64,
514    },
515    /// Radial flow: f(z) = z + beta * h(alpha, r)(z - z0)
516    Radial {
517        /// Center point
518        z0: Array1<f64>,
519        /// Scale parameter (log-parameterized for positivity)
520        log_alpha: f64,
521        /// Contraction/expansion parameter
522        beta: f64,
523    },
524}
525
526impl NormalizingFlowVI {
527    /// Create a new normalizing flow VI with a base mean-field Gaussian
528    pub fn new(dim: usize, n_flows: usize) -> Result<Self> {
529        check_positive(dim, "dim")?;
530        let base = MeanFieldGaussian::new(dim)?;
531
532        // Initialize with identity-like planar flows
533        let mut flows = Vec::with_capacity(n_flows);
534        for _ in 0..n_flows {
535            let u = Array1::from_elem(dim, 0.01);
536            let w = Array1::from_elem(dim, 0.01);
537            flows.push(FlowLayer::Planar { u, w, b: 0.0 });
538        }
539
540        Ok(Self { base, flows, dim })
541    }
542
543    /// Add a planar flow layer
544    pub fn add_planar_flow(&mut self, u: Array1<f64>, w: Array1<f64>, b: f64) -> Result<()> {
545        if u.len() != self.dim || w.len() != self.dim {
546            return Err(StatsError::DimensionMismatch(format!(
547                "u ({}) and w ({}) must have dimension {}",
548                u.len(),
549                w.len(),
550                self.dim
551            )));
552        }
553        self.flows.push(FlowLayer::Planar { u, w, b });
554        Ok(())
555    }
556
557    /// Add a radial flow layer
558    pub fn add_radial_flow(&mut self, z0: Array1<f64>, log_alpha: f64, beta: f64) -> Result<()> {
559        if z0.len() != self.dim {
560            return Err(StatsError::DimensionMismatch(format!(
561                "z0 ({}) must have dimension {}",
562                z0.len(),
563                self.dim
564            )));
565        }
566        self.flows.push(FlowLayer::Radial {
567            z0,
568            log_alpha,
569            beta,
570        });
571        Ok(())
572    }
573
574    /// Transform a sample through all flow layers, returning the transformed
575    /// sample and the sum of log-abs-det-Jacobians
576    pub fn transform(&self, z0: &Array1<f64>) -> Result<(Array1<f64>, f64)> {
577        if z0.len() != self.dim {
578            return Err(StatsError::DimensionMismatch(format!(
579                "z0 length ({}) must match dimension ({})",
580                z0.len(),
581                self.dim
582            )));
583        }
584        let mut z = z0.clone();
585        let mut sum_log_det_jac = 0.0;
586
587        for flow in &self.flows {
588            let (z_new, log_det) = apply_flow_layer(flow, &z)?;
589            z = z_new;
590            sum_log_det_jac += log_det;
591        }
592
593        Ok((z, sum_log_det_jac))
594    }
595
596    /// Sample from the flow-transformed distribution
597    pub fn sample(&self, epsilon: &Array1<f64>) -> Result<(Array1<f64>, f64)> {
598        let z0 = self.base.sample(epsilon)?;
599        let (z_k, sum_log_det) = self.transform(&z0)?;
600        let log_q0 = self.base.log_prob(&z0)?;
601        // log q_K(z_K) = log q_0(z_0) - sum log|det J_k|
602        let log_q_k = log_q0 - sum_log_det;
603        Ok((z_k, log_q_k))
604    }
605
606    /// Get the number of flow parameters
607    pub fn n_flow_params(&self) -> usize {
608        self.flows
609            .iter()
610            .map(|f| match f {
611                FlowLayer::Planar { u, w, .. } => u.len() + w.len() + 1,
612                FlowLayer::Radial { z0, .. } => z0.len() + 2,
613            })
614            .sum()
615    }
616}
617
618/// Apply a single flow layer to a point z
619fn apply_flow_layer(flow: &FlowLayer, z: &Array1<f64>) -> Result<(Array1<f64>, f64)> {
620    match flow {
621        FlowLayer::Planar { u, w, b } => {
622            // f(z) = z + u * tanh(w^T z + b)
623            let activation = w.dot(z) + b;
624            let tanh_val = activation.tanh();
625            let z_new = z + &(u * tanh_val);
626
627            // log|det J| = log|1 + u^T * h'(w^T z + b) * w|
628            let dtanh = 1.0 - tanh_val * tanh_val;
629            let psi = w * dtanh;
630            let det = 1.0 + u.dot(&psi);
631            let log_det = det.abs().ln();
632
633            Ok((z_new, log_det))
634        }
635        FlowLayer::Radial {
636            z0,
637            log_alpha,
638            beta,
639        } => {
640            let alpha = log_alpha.exp();
641            let diff = z - z0;
642            let r = diff.dot(&diff).sqrt();
643            let h = 1.0 / (alpha + r);
644            let z_new = z + &(&diff * (*beta * h));
645
646            // log|det J| for radial flow
647            let d = z.len() as f64;
648            let h_prime = -1.0 / ((alpha + r) * (alpha + r));
649            let term1 = (1.0 + beta * h).powi(d as i32 - 1);
650            let term2 = 1.0 + beta * h + beta * h_prime * r;
651            let det = term1 * term2;
652            let log_det = det.abs().ln();
653
654            Ok((z_new, log_det))
655        }
656    }
657}
658
659// ============================================================================
660// Bayesian Linear Regression (existing)
661// ============================================================================
662
/// Mean-field variational inference for Bayesian linear regression
///
/// Approximates the posterior with a factorized distribution:
/// q(beta, tau) = q(beta)q(tau) where q(beta) ~ N(mu_beta, Sigma_beta) and q(tau) ~ Gamma(a_tau, b_tau)
///
/// `beta` are the regression coefficients and `tau` is the noise precision.
/// The variational fields start out equal to the prior and are overwritten
/// by the coordinate-ascent updates.
#[derive(Debug, Clone)]
pub struct VariationalBayesianRegression {
    /// Variational mean for coefficients
    pub mean_beta: Array1<f64>,
    /// Variational covariance for coefficients
    pub cov_beta: Array2<f64>,
    /// Variational shape parameter for precision
    pub shape_tau: f64,
    /// Variational rate parameter for precision
    pub rate_tau: f64,
    /// Prior mean of the coefficients
    pub prior_mean_beta: Array1<f64>,
    /// Prior covariance of the coefficients (inverted during fitting to
    /// obtain the prior precision)
    pub prior_cov_beta: Array2<f64>,
    /// Prior shape parameter of the Gamma distribution on the precision
    pub priorshape_tau: f64,
    /// Prior rate parameter of the Gamma distribution on the precision
    pub prior_rate_tau: f64,
    /// Model dimensionality
    pub n_features: usize,
    /// Whether to fit intercept
    pub fit_intercept: bool,
}
687
688impl VariationalBayesianRegression {
689    /// Create a new variational Bayesian regression model
690    pub fn new(n_features: usize, fit_intercept: bool) -> Result<Self> {
691        check_positive(n_features, "n_features")?;
692
693        // Initialize with weakly informative priors
694        let prior_mean_beta = Array1::zeros(n_features);
695        let prior_cov_beta = Array2::eye(n_features) * 100.0; // Large variance = weak prior
696        let priorshape_tau = 1e-3;
697        let prior_rate_tau = 1e-3;
698
699        Ok(Self {
700            mean_beta: prior_mean_beta.clone(),
701            cov_beta: prior_cov_beta.clone(),
702            shape_tau: priorshape_tau,
703            rate_tau: prior_rate_tau,
704            prior_mean_beta,
705            prior_cov_beta,
706            priorshape_tau,
707            prior_rate_tau,
708            n_features,
709            fit_intercept,
710        })
711    }
712
713    /// Set custom priors
714    pub fn with_priors(
715        mut self,
716        prior_mean_beta: Array1<f64>,
717        prior_cov_beta: Array2<f64>,
718        priorshape_tau: f64,
719        prior_rate_tau: f64,
720    ) -> Result<Self> {
721        checkarray_finite(&prior_mean_beta, "prior_mean_beta")?;
722        checkarray_finite(&prior_cov_beta, "prior_cov_beta")?;
723        check_positive(priorshape_tau, "priorshape_tau")?;
724        check_positive(prior_rate_tau, "prior_rate_tau")?;
725
726        self.prior_mean_beta = prior_mean_beta.clone();
727        self.prior_cov_beta = prior_cov_beta.clone();
728        self.priorshape_tau = priorshape_tau;
729        self.prior_rate_tau = prior_rate_tau;
730        self.mean_beta = prior_mean_beta;
731        self.cov_beta = prior_cov_beta;
732        self.shape_tau = priorshape_tau;
733        self.rate_tau = prior_rate_tau;
734
735        Ok(self)
736    }
737
738    /// Fit the model using coordinate ascent variational inference
739    pub fn fit(
740        &mut self,
741        x: ArrayView2<f64>,
742        y: ArrayView1<f64>,
743        max_iter: usize,
744        tol: f64,
745    ) -> Result<VariationalRegressionResult> {
746        checkarray_finite(&x, "x")?;
747        checkarray_finite(&y, "y")?;
748        check_positive(max_iter, "max_iter")?;
749        check_positive(tol, "tol")?;
750
751        let (n_samples_, n_features) = x.dim();
752        if y.len() != n_samples_ {
753            return Err(StatsError::DimensionMismatch(format!(
754                "y length ({}) must match x rows ({})",
755                y.len(),
756                n_samples_
757            )));
758        }
759
760        if n_features != self.n_features {
761            return Err(StatsError::DimensionMismatch(format!(
762                "x features ({}) must match model features ({})",
763                n_features, self.n_features
764            )));
765        }
766
767        // Center data if fitting intercept
768        let (x_centered, y_centered, x_mean, y_mean) = if self.fit_intercept {
769            let x_mean = x.mean_axis(Axis(0)).expect("Operation failed");
770            let y_mean = y.mean();
771
772            let mut x_centered = x.to_owned();
773            for mut row in x_centered.rows_mut() {
774                row -= &x_mean;
775            }
776            let y_centered = &y.to_owned() - y_mean;
777
778            (x_centered, y_centered, Some(x_mean), Some(y_mean))
779        } else {
780            (x.to_owned(), y.to_owned(), None, None)
781        };
782
783        // Precompute matrices
784        let xtx = x_centered.t().dot(&x_centered);
785        let xty = x_centered.t().dot(&y_centered);
786        let yty = y_centered.dot(&y_centered);
787
788        // Prior precision matrix
789        let prior_precision =
790            scirs2_linalg::inv(&self.prior_cov_beta.view(), None).map_err(|e| {
791                StatsError::ComputationError(format!("Failed to invert prior covariance: {}", e))
792            })?;
793
794        let mut prev_elbo = f64::NEG_INFINITY;
795        let mut elbo_history = Vec::new();
796
797        for _iter in 0..max_iter {
798            // Update q(beta)
799            self.update_beta_variational(&xtx, &xty, &prior_precision)?;
800
801            // Update q(tau)
802            self.update_tau_variational(n_samples_ as f64, &xtx, yty)?;
803
804            // Compute ELBO
805            let elbo = self.compute_elbo(n_samples_ as f64, &xtx, &xty, yty, &prior_precision)?;
806            elbo_history.push(elbo);
807
808            // Check convergence
809            if _iter > 0 && (elbo - prev_elbo).abs() < tol {
810                break;
811            }
812
813            prev_elbo = elbo;
814        }
815
816        Ok(VariationalRegressionResult {
817            mean_beta: self.mean_beta.clone(),
818            cov_beta: self.cov_beta.clone(),
819            shape_tau: self.shape_tau,
820            rate_tau: self.rate_tau,
821            elbo: prev_elbo,
822            elbo_history: elbo_history.clone(),
823            n_samples_,
824            n_features: self.n_features,
825            x_mean,
826            y_mean,
827            converged: elbo_history.len() < max_iter,
828        })
829    }
830
831    /// Update variational distribution for beta
832    fn update_beta_variational(
833        &mut self,
834        xtx: &Array2<f64>,
835        xty: &Array1<f64>,
836        prior_precision: &Array2<f64>,
837    ) -> Result<()> {
838        // Expected precision: E[tau] = shape / rate
839        let expected_tau = self.shape_tau / self.rate_tau;
840
841        // Posterior precision
842        let precision_beta = prior_precision + &(xtx * expected_tau);
843
844        // Posterior covariance
845        self.cov_beta = scirs2_linalg::inv(&precision_beta.view(), None).map_err(|e| {
846            StatsError::ComputationError(format!("Failed to invert precision: {}", e))
847        })?;
848
849        // Posterior mean
850        let prior_contrib = prior_precision.dot(&self.prior_mean_beta);
851        let data_contrib = xty * expected_tau;
852        self.mean_beta = self.cov_beta.dot(&(prior_contrib + data_contrib));
853
854        Ok(())
855    }
856
857    /// Update variational distribution for tau
858    fn update_tau_variational(
859        &mut self,
860        n_samples_: f64,
861        xtx: &Array2<f64>,
862        yty: f64,
863    ) -> Result<()> {
864        // Shape parameter
865        self.shape_tau = self.priorshape_tau + n_samples_ / 2.0;
866
867        // Rate parameter
868        let expected_beta_outer = &self.cov_beta + outer_product(&self.mean_beta);
869        let trace_term = (xtx * &expected_beta_outer).sum();
870        let quadratic_term = 2.0 * self.mean_beta.dot(&xtx.dot(&self.mean_beta));
871
872        self.rate_tau = self.prior_rate_tau + 0.5 * (yty - quadratic_term + trace_term);
873
874        Ok(())
875    }
876
877    /// Compute Evidence Lower BOund (ELBO)
878    fn compute_elbo(
879        &self,
880        n_samples_: f64,
881        xtx: &Array2<f64>,
882        xty: &Array1<f64>,
883        yty: f64,
884        prior_precision: &Array2<f64>,
885    ) -> Result<f64> {
886        let expected_tau = self.shape_tau / self.rate_tau;
887        let expected_log_tau = digamma(self.shape_tau) - self.rate_tau.ln();
888
889        // E[log p(y|X,beta,tau)]
890        let diff =
891            yty - 2.0 * self.mean_beta.dot(xty) + self.mean_beta.dot(&xtx.dot(&self.mean_beta));
892        let trace_term = (xtx * &self.cov_beta).sum();
893        let likelihood_term = 0.5 * n_samples_ * expected_log_tau
894            - 0.5 * n_samples_ * (2.0_f64 * PI).ln()
895            - 0.5 * expected_tau * (diff + trace_term);
896
897        // E[log p(beta)]
898        let beta_diff = &self.mean_beta - &self.prior_mean_beta;
899        let beta_quad = beta_diff.dot(&prior_precision.dot(&beta_diff));
900        let beta_trace = (prior_precision * &self.cov_beta).sum();
901
902        let prior_det = scirs2_linalg::det(&prior_precision.view(), None).map_err(|e| {
903            StatsError::ComputationError(format!("Failed to compute determinant: {}", e))
904        })?;
905
906        let beta_prior_term = 0.5 * prior_det.ln()
907            - 0.5 * self.n_features as f64 * (2.0_f64 * PI).ln()
908            - 0.5 * (beta_quad + beta_trace);
909
910        // E[log p(tau)]
911        let tau_prior_term = self.priorshape_tau * self.prior_rate_tau.ln()
912            - lgamma(self.priorshape_tau)
913            + (self.priorshape_tau - 1.0) * expected_log_tau
914            - self.prior_rate_tau * expected_tau;
915
916        // -E[log q(beta)]
917        let var_det = scirs2_linalg::det(&self.cov_beta.view(), None).map_err(|e| {
918            StatsError::ComputationError(format!("Failed to compute determinant: {}", e))
919        })?;
920        let beta_entropy =
921            0.5 * self.n_features as f64 * (1.0 + (2.0_f64 * PI).ln()) + 0.5 * var_det.ln();
922
923        // -E[log q(tau)]
924        let tau_entropy = self.shape_tau - self.rate_tau.ln()
925            + lgamma(self.shape_tau)
926            + (1.0 - self.shape_tau) * digamma(self.shape_tau);
927
928        Ok(likelihood_term + beta_prior_term + tau_prior_term + beta_entropy + tau_entropy)
929    }
930
931    /// Predict on new data
932    pub fn predict(
933        &self,
934        x: ArrayView2<f64>,
935        result: &VariationalRegressionResult,
936    ) -> Result<VariationalPredictionResult> {
937        checkarray_finite(&x, "x")?;
938        let (n_test, n_features) = x.dim();
939
940        if n_features != result.n_features {
941            return Err(StatsError::DimensionMismatch(format!(
942                "x has {} features, expected {}",
943                n_features, result.n_features
944            )));
945        }
946
947        // Center test data if model was fit with intercept
948        let x_centered = if let Some(ref x_mean) = result.x_mean {
949            let mut x_c = x.to_owned();
950            for mut row in x_c.rows_mut() {
951                row -= x_mean;
952            }
953            x_c
954        } else {
955            x.to_owned()
956        };
957
958        // Predictive mean
959        let y_pred_centered = x_centered.dot(&result.mean_beta);
960        let y_pred = if let Some(y_mean) = result.y_mean {
961            &y_pred_centered + y_mean
962        } else {
963            y_pred_centered.clone()
964        };
965
966        // Predictive variance
967        let expected_noise_variance = result.rate_tau / result.shape_tau;
968        let mut predictive_variance = Array1::zeros(n_test);
969
970        for i in 0..n_test {
971            let x_row = x_centered.row(i);
972            let model_variance = x_row.dot(&result.cov_beta.dot(&x_row));
973            predictive_variance[i] = expected_noise_variance + model_variance;
974        }
975
976        Ok(VariationalPredictionResult {
977            mean: y_pred,
978            variance: predictive_variance.clone(),
979            model_uncertainty: predictive_variance.mapv(|v| (v - expected_noise_variance).max(0.0)),
980            noise_variance: expected_noise_variance,
981        })
982    }
983}
984
985/// Results from variational Bayesian regression
986#[derive(Debug, Clone)]
987pub struct VariationalRegressionResult {
988    /// Posterior mean of coefficients
989    pub mean_beta: Array1<f64>,
990    /// Posterior covariance of coefficients
991    pub cov_beta: Array2<f64>,
992    /// Posterior shape parameter for precision
993    pub shape_tau: f64,
994    /// Posterior rate parameter for precision
995    pub rate_tau: f64,
996    /// Final ELBO value
997    pub elbo: f64,
998    /// ELBO history during optimization
999    pub elbo_history: Vec<f64>,
1000    /// Number of training samples
1001    pub n_samples_: usize,
1002    /// Number of features
1003    pub n_features: usize,
1004    /// Training data mean (for centering)
1005    pub x_mean: Option<Array1<f64>>,
1006    /// Training target mean (for centering)
1007    pub y_mean: Option<f64>,
1008    /// Whether optimization converged
1009    pub converged: bool,
1010}
1011
1012impl VariationalRegressionResult {
1013    /// Get posterior standard deviations of coefficients
1014    pub fn std_beta(&self) -> Array1<f64> {
1015        self.cov_beta.diag().mapv(f64::sqrt)
1016    }
1017
1018    /// Get posterior mean and standard deviation of noise precision
1019    pub fn precision_stats(&self) -> (f64, f64) {
1020        let mean = self.shape_tau / self.rate_tau;
1021        let variance = self.shape_tau / (self.rate_tau * self.rate_tau);
1022        (mean, variance.sqrt())
1023    }
1024
1025    /// Compute credible intervals for coefficients
1026    pub fn credible_intervals(&self, confidence: f64) -> Result<Array2<f64>> {
1027        check_probability(confidence, "confidence")?;
1028
1029        let n_features = self.mean_beta.len();
1030        let mut intervals = Array2::zeros((n_features, 2));
1031        let alpha = (1.0 - confidence) / 2.0;
1032
1033        // Use normal approximation for coefficients
1034        for i in 0..n_features {
1035            let mean = self.mean_beta[i];
1036            let std = self.cov_beta[[i, i]].sqrt();
1037
1038            // Using standard normal quantiles (approximate)
1039            let z_critical = normal_ppf(1.0 - alpha)?;
1040            intervals[[i, 0]] = mean - z_critical * std;
1041            intervals[[i, 1]] = mean + z_critical * std;
1042        }
1043
1044        Ok(intervals)
1045    }
1046}
1047
1048/// Results from variational prediction
1049#[derive(Debug, Clone)]
1050pub struct VariationalPredictionResult {
1051    /// Predictive mean
1052    pub mean: Array1<f64>,
1053    /// Total predictive variance (model + noise)
1054    pub variance: Array1<f64>,
1055    /// Model uncertainty component
1056    pub model_uncertainty: Array1<f64>,
1057    /// Noise variance
1058    pub noise_variance: f64,
1059}
1060
1061impl VariationalPredictionResult {
1062    /// Get predictive standard deviations
1063    pub fn std(&self) -> Array1<f64> {
1064        self.variance.mapv(f64::sqrt)
1065    }
1066
1067    /// Compute predictive credible intervals
1068    pub fn credible_intervals(&self, confidence: f64) -> Result<Array2<f64>> {
1069        check_probability(confidence, "confidence")?;
1070
1071        let n_predictions = self.mean.len();
1072        let mut intervals = Array2::zeros((n_predictions, 2));
1073        let alpha = (1.0 - confidence) / 2.0;
1074
1075        let z_critical = normal_ppf(1.0 - alpha)?;
1076
1077        for i in 0..n_predictions {
1078            let mean = self.mean[i];
1079            let std = self.variance[i].sqrt();
1080            intervals[[i, 0]] = mean - z_critical * std;
1081            intervals[[i, 1]] = mean + z_critical * std;
1082        }
1083
1084        Ok(intervals)
1085    }
1086}
1087
1088// ============================================================================
1089// Automatic Relevance Determination (existing)
1090// ============================================================================
1091
1092/// Automatic Relevance Determination with Variational Inference
1093///
1094/// Uses sparse priors to perform automatic feature selection
1095#[derive(Debug, Clone)]
1096pub struct VariationalARD {
1097    /// Variational mean for coefficients
1098    pub mean_beta: Array1<f64>,
1099    /// Variational variance for coefficients (diagonal)
1100    pub var_beta: Array1<f64>,
1101    /// Variational parameters for precision (alpha)
1102    pub shape_alpha: Array1<f64>,
1103    pub rate_alpha: Array1<f64>,
1104    /// Variational parameters for noise precision
1105    pub shape_tau: f64,
1106    pub rate_tau: f64,
1107    /// Prior parameters
1108    pub priorshape_alpha: f64,
1109    pub prior_rate_alpha: f64,
1110    pub priorshape_tau: f64,
1111    pub prior_rate_tau: f64,
1112    /// Model parameters
1113    pub n_features: usize,
1114    pub fit_intercept: bool,
1115}
1116
1117impl VariationalARD {
1118    /// Create new Variational ARD model
1119    pub fn new(n_features: usize, fit_intercept: bool) -> Result<Self> {
1120        check_positive(n_features, "n_features")?;
1121
1122        // Weakly informative priors
1123        let priorshape_alpha = 1e-3;
1124        let prior_rate_alpha = 1e-3;
1125        let priorshape_tau = 1e-3;
1126        let prior_rate_tau = 1e-3;
1127
1128        Ok(Self {
1129            mean_beta: Array1::zeros(n_features),
1130            var_beta: Array1::from_elem(n_features, 1.0),
1131            shape_alpha: Array1::from_elem(n_features, priorshape_alpha),
1132            rate_alpha: Array1::from_elem(n_features, prior_rate_alpha),
1133            shape_tau: priorshape_tau,
1134            rate_tau: prior_rate_tau,
1135            priorshape_alpha,
1136            prior_rate_alpha,
1137            priorshape_tau,
1138            prior_rate_tau,
1139            n_features,
1140            fit_intercept,
1141        })
1142    }
1143
1144    /// Fit ARD model using variational inference
1145    pub fn fit(
1146        &mut self,
1147        x: ArrayView2<f64>,
1148        y: ArrayView1<f64>,
1149        max_iter: usize,
1150        tol: f64,
1151    ) -> Result<VariationalARDResult> {
1152        checkarray_finite(&x, "x")?;
1153        checkarray_finite(&y, "y")?;
1154        check_positive(max_iter, "max_iter")?;
1155        check_positive(tol, "tol")?;
1156
1157        let (n_samples_, n_features) = x.dim();
1158        if y.len() != n_samples_ {
1159            return Err(StatsError::DimensionMismatch(format!(
1160                "y length ({}) must match x rows ({})",
1161                y.len(),
1162                n_samples_
1163            )));
1164        }
1165
1166        // Center data if fitting intercept
1167        let (x_centered, y_centered, x_mean, y_mean) = if self.fit_intercept {
1168            let x_mean = x.mean_axis(Axis(0)).expect("Operation failed");
1169            let y_mean = y.mean();
1170
1171            let mut x_centered = x.to_owned();
1172            for mut row in x_centered.rows_mut() {
1173                row -= &x_mean;
1174            }
1175            let y_centered = &y.to_owned() - y_mean;
1176
1177            (x_centered, y_centered, Some(x_mean), Some(y_mean))
1178        } else {
1179            (x.to_owned(), y.to_owned(), None, None)
1180        };
1181
1182        // Precompute matrices
1183        let xtx = x_centered.t().dot(&x_centered);
1184        let xty = x_centered.t().dot(&y_centered);
1185        let yty = y_centered.dot(&y_centered);
1186
1187        let mut prev_elbo = f64::NEG_INFINITY;
1188        let mut elbo_history = Vec::new();
1189
1190        for _iter in 0..max_iter {
1191            // Update q(beta)
1192            self.update_beta_ard(&xtx, &xty)?;
1193
1194            // Update q(alpha)
1195            self.update_alpha_ard()?;
1196
1197            // Update q(tau)
1198            self.update_tau_ard(n_samples_ as f64, &xtx, yty)?;
1199
1200            // Compute ELBO
1201            let elbo = self.compute_elbo_ard(n_samples_ as f64, &xtx, &xty, yty)?;
1202            elbo_history.push(elbo);
1203
1204            // Check convergence
1205            if _iter > 0 && (elbo - prev_elbo).abs() < tol {
1206                break;
1207            }
1208
1209            // Prune irrelevant features
1210            if _iter % 10 == 0 {
1211                self.prune_features()?;
1212            }
1213
1214            prev_elbo = elbo;
1215        }
1216
1217        Ok(VariationalARDResult {
1218            mean_beta: self.mean_beta.clone(),
1219            var_beta: self.var_beta.clone(),
1220            shape_alpha: self.shape_alpha.clone(),
1221            rate_alpha: self.rate_alpha.clone(),
1222            shape_tau: self.shape_tau,
1223            rate_tau: self.rate_tau,
1224            elbo: prev_elbo,
1225            elbo_history: elbo_history.clone(),
1226            n_samples_,
1227            n_features: self.n_features,
1228            x_mean,
1229            y_mean,
1230            converged: elbo_history.len() < max_iter,
1231        })
1232    }
1233
1234    /// Update variational distribution for beta in ARD model
1235    fn update_beta_ard(&mut self, xtx: &Array2<f64>, xty: &Array1<f64>) -> Result<()> {
1236        let expected_tau = self.shape_tau / self.rate_tau;
1237        let expected_alpha = &self.shape_alpha / &self.rate_alpha;
1238
1239        // Update variance (diagonal approximation)
1240        for i in 0..self.n_features {
1241            let precision_i = expected_alpha[i] + expected_tau * xtx[[i, i]];
1242            self.var_beta[i] = 1.0 / precision_i;
1243        }
1244
1245        // Update mean
1246        for i in 0..self.n_features {
1247            let sum_j = (0..self.n_features)
1248                .filter(|&j| j != i)
1249                .map(|j| xtx[[i, j]] * self.mean_beta[j])
1250                .sum::<f64>();
1251
1252            self.mean_beta[i] = expected_tau * self.var_beta[i] * (xty[i] - sum_j);
1253        }
1254
1255        Ok(())
1256    }
1257
1258    /// Update variational distribution for alpha (precision parameters)
1259    fn update_alpha_ard(&mut self) -> Result<()> {
1260        for i in 0..self.n_features {
1261            self.shape_alpha[i] = self.priorshape_alpha + 0.5;
1262            self.rate_alpha[i] =
1263                self.prior_rate_alpha + 0.5 * (self.mean_beta[i].powi(2) + self.var_beta[i]);
1264        }
1265
1266        Ok(())
1267    }
1268
1269    /// Update variational distribution for tau (noise precision)
1270    fn update_tau_ard(&mut self, n_samples_: f64, xtx: &Array2<f64>, yty: f64) -> Result<()> {
1271        self.shape_tau = self.priorshape_tau + n_samples_ / 2.0;
1272
1273        let mut quadratic_term = 0.0;
1274        for i in 0..self.n_features {
1275            for j in 0..self.n_features {
1276                if i == j {
1277                    quadratic_term += xtx[[i, j]] * (self.mean_beta[i].powi(2) + self.var_beta[i]);
1278                } else {
1279                    quadratic_term += xtx[[i, j]] * self.mean_beta[i] * self.mean_beta[j];
1280                }
1281            }
1282        }
1283
1284        self.rate_tau = self.prior_rate_tau
1285            + 0.5 * (yty - 2.0 * self.mean_beta.dot(&xtx.dot(&self.mean_beta)) + quadratic_term);
1286
1287        Ok(())
1288    }
1289
1290    /// Compute ELBO for ARD model
1291    fn compute_elbo_ard(
1292        &self,
1293        n_samples_: f64,
1294        xtx: &Array2<f64>,
1295        xty: &Array1<f64>,
1296        yty: f64,
1297    ) -> Result<f64> {
1298        let expected_tau = self.shape_tau / self.rate_tau;
1299        let expected_log_tau = digamma(self.shape_tau) - self.rate_tau.ln();
1300
1301        // Likelihood term
1302        let mut quadratic_form = yty - 2.0 * self.mean_beta.dot(xty);
1303        for i in 0..self.n_features {
1304            for j in 0..self.n_features {
1305                if i == j {
1306                    quadratic_form += xtx[[i, j]] * (self.mean_beta[i].powi(2) + self.var_beta[i]);
1307                } else {
1308                    quadratic_form += xtx[[i, j]] * self.mean_beta[i] * self.mean_beta[j];
1309                }
1310            }
1311        }
1312
1313        let likelihood_term = 0.5 * n_samples_ * expected_log_tau
1314            - 0.5 * n_samples_ * (2.0_f64 * PI).ln()
1315            - 0.5 * expected_tau * quadratic_form;
1316
1317        // Prior terms
1318        let mut prior_term = 0.0;
1319        for i in 0..self.n_features {
1320            let expected_alpha_i = self.shape_alpha[i] / self.rate_alpha[i];
1321            let expected_log_alpha_i = digamma(self.shape_alpha[i]) - self.rate_alpha[i].ln();
1322
1323            prior_term += 0.5 * expected_log_alpha_i
1324                - 0.5 * (2.0_f64 * PI).ln()
1325                - 0.5 * expected_alpha_i * (self.mean_beta[i].powi(2) + self.var_beta[i]);
1326        }
1327
1328        // Entropy terms
1329        let mut entropy_term = 0.0;
1330        for i in 0..self.n_features {
1331            entropy_term += 0.5 * (1.0 + (2.0 * PI * self.var_beta[i]).ln());
1332        }
1333
1334        Ok(likelihood_term + prior_term + entropy_term)
1335    }
1336
1337    /// Prune features with small precision (large variance in prior)
1338    fn prune_features(&mut self) -> Result<()> {
1339        let threshold = 1e12; // Large precision = irrelevant feature
1340
1341        for i in 0..self.n_features {
1342            let expected_alpha = self.shape_alpha[i] / self.rate_alpha[i];
1343            if expected_alpha > threshold {
1344                // Feature is irrelevant, set to zero
1345                self.mean_beta[i] = 0.0;
1346                self.var_beta[i] = 1e-12;
1347            }
1348        }
1349
1350        Ok(())
1351    }
1352
1353    /// Get relevance scores for features
1354    pub fn feature_relevance(&self) -> Array1<f64> {
1355        let expected_alpha = &self.shape_alpha / &self.rate_alpha;
1356        // Relevance is inverse of precision (features with low precision are more relevant)
1357        expected_alpha.mapv(|alpha| 1.0 / alpha)
1358    }
1359}
1360
1361/// Results from Variational ARD
1362#[derive(Debug, Clone)]
1363pub struct VariationalARDResult {
1364    /// Posterior mean of coefficients
1365    pub mean_beta: Array1<f64>,
1366    /// Posterior variance of coefficients
1367    pub var_beta: Array1<f64>,
1368    /// Posterior shape parameters for feature precisions
1369    pub shape_alpha: Array1<f64>,
1370    /// Posterior rate parameters for feature precisions
1371    pub rate_alpha: Array1<f64>,
1372    /// Posterior shape parameter for noise precision
1373    pub shape_tau: f64,
1374    /// Posterior rate parameter for noise precision
1375    pub rate_tau: f64,
1376    /// Final ELBO value
1377    pub elbo: f64,
1378    /// ELBO history
1379    pub elbo_history: Vec<f64>,
1380    /// Number of training samples
1381    pub n_samples_: usize,
1382    /// Number of features
1383    pub n_features: usize,
1384    /// Training data mean
1385    pub x_mean: Option<Array1<f64>>,
1386    /// Training target mean
1387    pub y_mean: Option<f64>,
1388    /// Whether optimization converged
1389    pub converged: bool,
1390}
1391
1392impl VariationalARDResult {
1393    /// Get selected features based on relevance threshold
1394    pub fn selected_features(&self, threshold: f64) -> Vec<usize> {
1395        let expected_alpha = &self.shape_alpha / &self.rate_alpha;
1396        expected_alpha
1397            .iter()
1398            .enumerate()
1399            .filter(|(_, &alpha)| alpha < threshold) // Low precision = high relevance
1400            .map(|(i, _)| i)
1401            .collect()
1402    }
1403
1404    /// Get feature importance scores
1405    pub fn feature_importance(&self) -> Array1<f64> {
1406        self.mean_beta.mapv(f64::abs)
1407    }
1408}
1409
1410// ============================================================================
1411// Shared helper functions
1412// ============================================================================
1413
1414/// Compute outer product of a vector with itself
1415pub(crate) fn outer_product(v: &Array1<f64>) -> Array2<f64> {
1416    let n = v.len();
1417    let mut result = Array2::zeros((n, n));
1418    for i in 0..n {
1419        for j in 0..n {
1420            result[[i, j]] = v[i] * v[j];
1421        }
1422    }
1423    result
1424}
1425
1426/// Approximate normal PPF using rational approximation (Beasley-Springer-Moro)
1427pub(crate) fn normal_ppf(p: f64) -> Result<f64> {
1428    if p <= 0.0 || p >= 1.0 {
1429        return Err(StatsError::InvalidArgument(
1430            "p must be between 0 and 1".to_string(),
1431        ));
1432    }
1433
1434    let a = [
1435        -3.969683028665376e+01,
1436        2.209460984245205e+02,
1437        -2.759285104469687e+02,
1438        1.383_577_518_672_69e2,
1439        -3.066479806614716e+01,
1440        2.506628277459239e+00,
1441    ];
1442
1443    let b = [
1444        -5.447609879822406e+01,
1445        1.615858368580409e+02,
1446        -1.556989798598866e+02,
1447        6.680131188771972e+01,
1448        -1.328068155288572e+01,
1449    ];
1450
1451    let c = [
1452        -7.784894002430293e-03,
1453        -3.223964580411365e-01,
1454        -2.400758277161838e+00,
1455        -2.549732539343734e+00,
1456        4.374664141464968e+00,
1457        2.938163982698783e+00,
1458    ];
1459
1460    let d = [
1461        7.784695709041462e-03,
1462        3.224671290700398e-01,
1463        2.445134137142996e+00,
1464        3.754408661907416e+00,
1465    ];
1466
1467    let p_low = 0.02425;
1468    let p_high = 1.0 - p_low;
1469
1470    if p < p_low {
1471        let q = (-2.0 * p.ln()).sqrt();
1472        Ok(
1473            (((((c[0] * q + c[1]) * q + c[2]) * q + c[3]) * q + c[4]) * q + c[5])
1474                / ((((d[0] * q + d[1]) * q + d[2]) * q + d[3]) * q + 1.0),
1475        )
1476    } else if p <= p_high {
1477        let q = p - 0.5;
1478        let r = q * q;
1479        Ok(
1480            (((((a[0] * r + a[1]) * r + a[2]) * r + a[3]) * r + a[4]) * r + a[5]) * q
1481                / (((((b[0] * r + b[1]) * r + b[2]) * r + b[3]) * r + b[4]) * r + 1.0),
1482        )
1483    } else {
1484        let q = (-2.0 * (1.0 - p).ln()).sqrt();
1485        Ok(
1486            (-((((c[0] * q + c[1]) * q + c[2]) * q + c[3]) * q + c[4]) * q + c[5])
1487                / ((((d[0] * q + d[1]) * q + d[2]) * q + d[3]) * q + 1.0),
1488        )
1489    }
1490}
1491
/// Digamma function (approximate).
///
/// Non-positive arguments yield `f64::NEG_INFINITY`. Arguments below `8.0`
/// are shifted upward one unit at a time via
/// `digamma(x) = digamma(x + 1) - 1 / x`; from `8.0` upward a five-term
/// asymptotic expansion in `1 / x` is evaluated.
pub(crate) fn digamma(x: f64) -> f64 {
    if x <= 0.0 {
        return f64::NEG_INFINITY;
    }

    // Walk the argument up to at least 8, remembering every point visited.
    // A positive start needs at most eight unit steps, so the buffer cannot
    // overflow.
    let mut visited = [0.0_f64; 8];
    let mut steps = 0;
    let mut z = x;
    while z < 8.0 {
        visited[steps] = z;
        steps += 1;
        z += 1.0;
    }

    let r = 1.0 / z;
    let r2 = r * r;
    let r4 = r2 * r2;
    let asymptotic = z.ln() - 0.5 * r - r2 / 12.0 + r4 / 120.0 - r4 * r2 / 252.0;

    // Undo the shifts from the largest visited point back down to `x`.
    visited[..steps]
        .iter()
        .rev()
        .fold(asymptotic, |acc, &point| acc - 1.0 / point)
}
1508
/// Natural log of the gamma function (approximate, Lanczos approximation).
///
/// Non-positive arguments yield `f64::NEG_INFINITY`. Arguments below `0.5`
/// go through the reflection formula; everything else is evaluated with the
/// Lanczos series directly.
pub(crate) fn lgamma(x: f64) -> f64 {
    // Lanczos approximation (g=7, n=9) -- accurate to ~15 significant digits
    // Coefficients from Paul Godfrey's implementation
    const G: f64 = 7.0;
    const C: [f64; 9] = [
        0.99999999999980993,
        676.5203681218851,
        -1259.1392167224028,
        771.323_428_777_653_1,
        -176.615_029_162_140_6,
        12.507_343_278_686_905,
        -0.138_571_095_265_720_12,
        9.984_369_578_019_572e-6,
        1.505_632_735_149_311_6e-7,
    ];

    if x <= 0.0 {
        return f64::NEG_INFINITY;
    }

    if x < 0.5 {
        // Reflection: lgamma(x) = ln(pi / sin(pi * x)) - lgamma(1 - x)
        let reflected = lgamma(1.0 - x);
        return (PI / (PI * x).sin()).ln() - reflected;
    }

    let shifted = x - 1.0;
    let t = shifted + G + 0.5;

    // Partial-fraction series: C[0] + sum_k C[k] / (shifted + k), k = 1..=8.
    let series = C[1..]
        .iter()
        .zip(1u32..)
        .fold(C[0], |acc, (&coeff, k)| acc + coeff / (shifted + f64::from(k)));

    0.5 * (2.0 * PI).ln() + (shifted + 0.5) * t.ln() - t + series.ln()
}
1544
/// Trigamma function (derivative of digamma), approximate.
///
/// Non-positive arguments yield `f64::INFINITY`. Arguments below `8.0` are
/// shifted upward one unit at a time via
/// `trigamma(x) = trigamma(x + 1) + 1 / x^2`; from `8.0` upward a five-term
/// asymptotic expansion in `1 / x` is evaluated.
pub(crate) fn trigamma(x: f64) -> f64 {
    if x <= 0.0 {
        return f64::INFINITY;
    }

    // Walk the argument up to at least 8, remembering every point visited.
    // A positive start needs at most eight unit steps, so the buffer cannot
    // overflow.
    let mut visited = [0.0_f64; 8];
    let mut steps = 0;
    let mut z = x;
    while z < 8.0 {
        visited[steps] = z;
        steps += 1;
        z += 1.0;
    }

    let r = 1.0 / z;
    let r2 = r * r;
    let r4 = r2 * r2;
    let asymptotic = r + 0.5 * r2 + r2 * r / 6.0 - r4 * r / 30.0 + r4 * r2 * r / 42.0;

    // Undo the shifts from the largest visited point back down to `x`.
    visited[..steps]
        .iter()
        .rev()
        .fold(asymptotic, |acc, &point| acc + 1.0 / (point * point))
}
1561
1562// ============================================================================
1563// Tests
1564// ============================================================================
1565
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Absolute tolerance for the exact-value comparisons below.
    const TIGHT: f64 = 1e-10;

    /// Entropy of a `dim`-dimensional standard normal:
    /// `dim * 0.5 * (1 + ln(2 * pi))`.
    fn standard_normal_entropy(dim: f64) -> f64 {
        dim * 0.5 * (1.0 + (2.0 * PI).ln())
    }

    #[test]
    fn test_mean_field_gaussian_creation() {
        let family = MeanFieldGaussian::new(5).expect("should create mean-field Gaussian");

        assert_eq!(family.dim, 5);
        assert_eq!(family.means.len(), 5);
        assert_eq!(family.log_stds.len(), 5);
        assert_eq!(family.n_params(), 10);
    }

    #[test]
    fn test_mean_field_gaussian_entropy() {
        let family = MeanFieldGaussian::new(2).expect("should create");

        // A freshly created family is compared against a 2D standard normal.
        let gap = family.entropy() - standard_normal_entropy(2.0);
        assert!(gap.abs() < TIGHT);
    }

    #[test]
    fn test_mean_field_gaussian_sample() {
        let family = MeanFieldGaussian::new(3).expect("should create");
        let noise = Array1::from_vec(vec![0.5, -0.3, 1.0]);

        let draw = family.sample(&noise).expect("should sample");
        assert_eq!(draw.len(), 3);

        // With mean=0 and std=1, the draw should equal the noise vector.
        for idx in 0..3 {
            assert!((draw[idx] - noise[idx]).abs() < TIGHT);
        }
    }

    #[test]
    fn test_mean_field_gaussian_params_roundtrip() {
        let mut family = MeanFieldGaussian::new(3).expect("should create");
        let written = Array1::from_vec(vec![1.0, 2.0, 3.0, 0.5, -0.3, 0.1]);

        family.set_params(&written).expect("should set params");
        let read_back = family.get_params();

        for idx in 0..6 {
            assert!((read_back[idx] - written[idx]).abs() < TIGHT);
        }
    }

    #[test]
    fn test_full_rank_gaussian_creation() {
        let family = FullRankGaussian::new(3).expect("should create full-rank Gaussian");

        assert_eq!(family.dim, 3);
        assert_eq!(family.mean.len(), 3);
        // n_params = d + d*(d+1)/2 = 3 + 6 = 9
        assert_eq!(family.n_params(), 9);
    }

    #[test]
    fn test_full_rank_gaussian_entropy() {
        let family = FullRankGaussian::new(2).expect("should create");

        // For identity covariance the log-determinant term vanishes, leaving
        // the standard-normal entropy.
        let gap = family.entropy() - standard_normal_entropy(2.0);
        assert!(gap.abs() < TIGHT);
    }

    #[test]
    fn test_full_rank_gaussian_sample() {
        let family = FullRankGaussian::new(2).expect("should create");
        let noise = Array1::from_vec(vec![1.0, -1.0]);

        let draw = family.sample(&noise).expect("should sample");
        assert_eq!(draw.len(), 2);

        // With identity chol factor and zero mean, the draw equals the noise.
        for idx in 0..2 {
            assert!((draw[idx] - noise[idx]).abs() < TIGHT);
        }
    }

    #[test]
    fn test_normalizing_flow_creation() {
        let flow = NormalizingFlowVI::new(3, 2).expect("should create");

        assert_eq!(flow.dim, 3);
        assert_eq!(flow.flows.len(), 2);
    }

    #[test]
    fn test_normalizing_flow_transform() {
        let flow = NormalizingFlowVI::new(2, 1).expect("should create");
        let start = Array1::from_vec(vec![0.5, -0.5]);

        let (end, log_det) = flow.transform(&start).expect("should transform");

        assert_eq!(end.len(), 2);
        assert!(log_det.is_finite());
    }

    #[test]
    fn test_diagnostics() {
        let mut diagnostics = VariationalDiagnostics::new();
        for elbo in [-100.0, -90.0, -85.0] {
            diagnostics.record_elbo(elbo);
        }
        for norm in [10.0, 5.0] {
            diagnostics.record_gradient_norm(norm);
        }

        assert_eq!(diagnostics.n_iterations, 3);
        assert!(!diagnostics.check_elbo_convergence(1.0));
        assert!(diagnostics.check_elbo_convergence(10.0));

        let summary = diagnostics.elbo_summary();
        assert!((summary.min - (-100.0)).abs() < TIGHT);
        assert!((summary.max - (-85.0)).abs() < TIGHT);
        assert!(summary.monotonic);
    }

    #[test]
    fn test_variational_bayesian_regression() {
        // Simple regression: y = 2*x + 1 + noise, with a deterministic
        // pseudo-noise pattern so the test is reproducible.
        let n = 50;
        let abscissa = |i: usize| i as f64 / n as f64;

        let x = Array2::from_shape_fn((n, 1), |(i, _)| abscissa(i));
        let targets: Vec<f64> = (0..n)
            .map(|i| 2.0 * abscissa(i) + 1.0 + 0.1 * ((i * 7 % 13) as f64 - 6.0) / 6.0)
            .collect();
        let y = Array1::from_vec(targets);

        let mut model = VariationalBayesianRegression::new(1, true).expect("should create model");
        let fitted = model
            .fit(x.view(), y.view(), 100, 1e-6)
            .expect("should fit");

        // Check that coefficient is close to 2.0
        let slope = fitted.mean_beta[0];
        assert!(
            (slope - 2.0).abs() < 0.5,
            "beta should be close to 2.0, got {}",
            slope
        );
    }

    #[test]
    fn test_trigamma() {
        // trigamma(1) = pi^2/6
        let reference = PI * PI / 6.0;
        let value = trigamma(1.0);
        assert!(
            (value - reference).abs() < 0.01,
            "trigamma(1) should be close to pi^2/6, got {}",
            value
        );
    }
}