Skip to main content

scirs2_stats/variational/
normalizing_flow.rs

1//! Normalizing Flows for Variational Inference
2//!
3//! Implements invertible transformations that map a simple base distribution
4//! (e.g., standard Gaussian) to a more flexible posterior approximation.
5//!
6//! Supports:
7//! - **Planar flow**: `f(z) = z + u * tanh(w^T z + b)` (Rezende & Mohamed 2015)
8//! - **Radial flow**: `f(z) = z + beta * (z - z0) / (alpha + ||z - z0||)` (Rezende & Mohamed 2015)
9//! - **Flow chains**: Compose multiple flows `z_K = f_K . ... . f_1(z_0)`
//! - **ELBO with flow**: `E_{z_0 ~ q_0}[log p(x, z_K) - log q_0(z_0) + sum_k log|det(df_k/dz_{k-1})|]`
11//!
12//! These flows can be used to enhance ADVI by replacing mean-field with flow-based posteriors.
13
14use crate::error::{StatsError, StatsResult};
15use scirs2_core::ndarray::Array1;
16use std::f64::consts::PI;
17
18use super::{PosteriorResult, VariationalInference};
19
20// ============================================================================
21// Flow Types
22// ============================================================================
23
24/// Type of normalizing flow layer
25#[derive(Debug, Clone)]
26pub enum FlowLayer {
27    /// Planar flow: f(z) = z + u * tanh(w^T z + b)
28    Planar {
29        /// Weight vector w (dim)
30        w: Array1<f64>,
31        /// Scale vector u (dim)
32        u: Array1<f64>,
33        /// Bias scalar
34        b: f64,
35    },
36    /// Radial flow: f(z) = z + beta * (z - z0) / (alpha + ||z - z0||)
37    Radial {
38        /// Center point z0 (dim)
39        z0: Array1<f64>,
40        /// Scale parameter alpha > 0
41        alpha: f64,
42        /// Magnitude parameter beta
43        beta: f64,
44    },
45}
46
47impl FlowLayer {
48    /// Create a new planar flow layer with random initialization
49    pub fn new_planar(dim: usize, seed: u64) -> Self {
50        let golden = 1.618033988749895_f64;
51        let plastic = 1.324717957244746_f64;
52
53        let w = Array1::from_shape_fn(dim, |i| {
54            let u1 = ((seed as f64 * golden + i as f64 * plastic + 0.3) % 1.0)
55                .abs()
56                .max(1e-10)
57                .min(1.0 - 1e-10);
58            let u2 = ((seed as f64 * plastic + i as f64 * golden + 0.7) % 1.0)
59                .abs()
60                .max(1e-10)
61                .min(1.0 - 1e-10);
62            let r = (-2.0 * u1.ln()).sqrt();
63            r * (2.0 * PI * u2).cos() * 0.1
64        });
65
66        let u = Array1::from_shape_fn(dim, |i| {
67            let u1 = (((seed + 100) as f64 * golden + i as f64 * plastic + 0.1) % 1.0)
68                .abs()
69                .max(1e-10)
70                .min(1.0 - 1e-10);
71            let u2 = (((seed + 100) as f64 * plastic + i as f64 * golden + 0.9) % 1.0)
72                .abs()
73                .max(1e-10)
74                .min(1.0 - 1e-10);
75            let r = (-2.0 * u1.ln()).sqrt();
76            r * (2.0 * PI * u2).cos() * 0.1
77        });
78
79        let b_val = {
80            let u1 = ((seed as f64 * 0.37 + 0.5) % 1.0)
81                .abs()
82                .max(1e-10)
83                .min(1.0 - 1e-10);
84            let u2 = ((seed as f64 * 0.73 + 0.5) % 1.0)
85                .abs()
86                .max(1e-10)
87                .min(1.0 - 1e-10);
88            let r = (-2.0 * u1.ln()).sqrt();
89            r * (2.0 * PI * u2).cos() * 0.1
90        };
91
92        FlowLayer::Planar { w, u, b: b_val }
93    }
94
95    /// Create a new radial flow layer with random initialization
96    pub fn new_radial(dim: usize, seed: u64) -> Self {
97        let golden = 1.618033988749895_f64;
98        let plastic = 1.324717957244746_f64;
99
100        let z0 = Array1::from_shape_fn(dim, |i| {
101            let u1 = (((seed + 200) as f64 * golden + i as f64 * plastic + 0.2) % 1.0)
102                .abs()
103                .max(1e-10)
104                .min(1.0 - 1e-10);
105            let u2 = (((seed + 200) as f64 * plastic + i as f64 * golden + 0.8) % 1.0)
106                .abs()
107                .max(1e-10)
108                .min(1.0 - 1e-10);
109            let r = (-2.0 * u1.ln()).sqrt();
110            r * (2.0 * PI * u2).cos() * 0.1
111        });
112
113        FlowLayer::Radial {
114            z0,
115            alpha: 1.0,
116            beta: 0.1,
117        }
118    }
119
120    /// Apply the flow transformation: f(z) and compute log|det(df/dz)|
121    ///
122    /// Returns (f(z), log|det J|)
123    pub fn forward(&self, z: &Array1<f64>) -> StatsResult<(Array1<f64>, f64)> {
124        match self {
125            FlowLayer::Planar { w, u, b } => {
126                let dim = z.len();
127                if w.len() != dim || u.len() != dim {
128                    return Err(StatsError::DimensionMismatch(format!(
129                        "Flow dimension mismatch: z={}, w={}, u={}",
130                        dim,
131                        w.len(),
132                        u.len()
133                    )));
134                }
135
136                // Enforce invertibility: u_hat = u + (m(w^T u) - w^T u) * w / ||w||^2
137                // where m(x) = -1 + softplus(x) = -1 + log(1 + exp(x))
138                let u_hat = enforce_planar_invertibility(w, u);
139
140                let wtz = w.dot(z) + b;
141                let tanh_wtz = wtz.tanh();
142
143                // f(z) = z + u_hat * tanh(w^T z + b)
144                let fz = z + &(&u_hat * tanh_wtz);
145
146                // log|det J| = log|1 + u_hat^T * w * (1 - tanh^2(w^T z + b))|
147                let dtanh = 1.0 - tanh_wtz * tanh_wtz;
148                let psi = w * dtanh;
149                let det_term = 1.0 + u_hat.dot(&psi);
150
151                let log_det = det_term.abs().max(1e-15).ln();
152
153                Ok((fz, log_det))
154            }
155            FlowLayer::Radial { z0, alpha, beta } => {
156                let dim = z.len();
157                if z0.len() != dim {
158                    return Err(StatsError::DimensionMismatch(format!(
159                        "Flow dimension mismatch: z={}, z0={}",
160                        dim,
161                        z0.len()
162                    )));
163                }
164
165                let diff = z - z0;
166                let r = diff.dot(&diff).sqrt().max(1e-10);
167                let alpha_pos = alpha.abs().max(1e-6);
168
169                // Enforce beta >= -alpha to ensure invertibility
170                let beta_hat = -alpha_pos + softplus(*beta + alpha_pos);
171
172                let h = 1.0 / (alpha_pos + r);
173                let h_prime = -1.0 / ((alpha_pos + r) * (alpha_pos + r));
174
175                // f(z) = z + beta_hat * h(r) * (z - z0)
176                let fz = z + &(&diff * (beta_hat * h));
177
178                // log|det J| = (d-1) * log(1 + beta_hat * h)
179                //             + log(1 + beta_hat * h + beta_hat * h' * r)
180                let d = dim as f64;
181                let term1 = 1.0 + beta_hat * h;
182                let term2 = 1.0 + beta_hat * h + beta_hat * h_prime * r;
183
184                let log_det = (d - 1.0) * term1.abs().max(1e-15).ln() + term2.abs().max(1e-15).ln();
185
186                Ok((fz, log_det))
187            }
188        }
189    }
190
191    /// Get the total number of parameters for this flow layer
192    pub fn n_params(&self) -> usize {
193        match self {
194            FlowLayer::Planar { w, u, .. } => w.len() + u.len() + 1,
195            FlowLayer::Radial { z0, .. } => z0.len() + 2,
196        }
197    }
198
199    /// Get all parameters as a flat vector
200    pub fn get_params(&self) -> Array1<f64> {
201        match self {
202            FlowLayer::Planar { w, u, b } => {
203                let dim = w.len();
204                let mut params = Array1::zeros(2 * dim + 1);
205                for i in 0..dim {
206                    params[i] = w[i];
207                    params[dim + i] = u[i];
208                }
209                params[2 * dim] = *b;
210                params
211            }
212            FlowLayer::Radial { z0, alpha, beta } => {
213                let dim = z0.len();
214                let mut params = Array1::zeros(dim + 2);
215                for i in 0..dim {
216                    params[i] = z0[i];
217                }
218                params[dim] = *alpha;
219                params[dim + 1] = *beta;
220                params
221            }
222        }
223    }
224
225    /// Set parameters from a flat vector
226    pub fn set_params(&mut self, params: &Array1<f64>) -> StatsResult<()> {
227        match self {
228            FlowLayer::Planar { w, u, b } => {
229                let dim = w.len();
230                if params.len() != 2 * dim + 1 {
231                    return Err(StatsError::DimensionMismatch(format!(
232                        "Expected {} params, got {}",
233                        2 * dim + 1,
234                        params.len()
235                    )));
236                }
237                for i in 0..dim {
238                    w[i] = params[i];
239                    u[i] = params[dim + i];
240                }
241                *b = params[2 * dim];
242                Ok(())
243            }
244            FlowLayer::Radial { z0, alpha, beta } => {
245                let dim = z0.len();
246                if params.len() != dim + 2 {
247                    return Err(StatsError::DimensionMismatch(format!(
248                        "Expected {} params, got {}",
249                        dim + 2,
250                        params.len()
251                    )));
252                }
253                for i in 0..dim {
254                    z0[i] = params[i];
255                }
256                *alpha = params[dim];
257                *beta = params[dim + 1];
258                Ok(())
259            }
260        }
261    }
262}
263
264/// Enforce invertibility for planar flows by computing u_hat
265/// such that w^T u_hat >= -1
266fn enforce_planar_invertibility(w: &Array1<f64>, u: &Array1<f64>) -> Array1<f64> {
267    let wtu = w.dot(u);
268    let w_norm_sq = w.dot(w);
269    if w_norm_sq < 1e-15 {
270        return u.clone();
271    }
272    // m(x) = -1 + softplus(x) = -1 + log(1 + exp(x))
273    let m_wtu = -1.0 + softplus(wtu);
274    if (m_wtu - wtu).abs() < 1e-15 {
275        return u.clone();
276    }
277    u + &(w * ((m_wtu - wtu) / w_norm_sq))
278}
279
/// Numerically stable softplus, `log(1 + exp(x))`.
///
/// Evaluated as `max(x, 0) + ln_1p(exp(-|x|))`:
/// - `exp` is only ever called with a non-positive argument, so it cannot overflow.
/// - `ln_1p` keeps full relative precision when the result is tiny (very negative `x`),
///   where the naive `(1 + exp(x)).ln()` loses digits.
///
/// Infinities map to `+inf` / `0.0`, and NaN propagates.
fn softplus(x: f64) -> f64 {
    x.max(0.0) + (-x.abs()).exp().ln_1p()
}
290
291// ============================================================================
292// Flow Chain
293// ============================================================================
294
295/// A chain of normalizing flow layers
296#[derive(Debug, Clone)]
297pub struct NormalizingFlowChain {
298    /// Ordered list of flow layers
299    pub layers: Vec<FlowLayer>,
300}
301
302impl NormalizingFlowChain {
303    /// Create a new flow chain with the given layers
304    pub fn new(layers: Vec<FlowLayer>) -> Self {
305        Self { layers }
306    }
307
308    /// Create a chain of planar flows
309    pub fn planar(dim: usize, n_layers: usize, seed: u64) -> Self {
310        let layers = (0..n_layers)
311            .map(|i| FlowLayer::new_planar(dim, seed + i as u64 * 7))
312            .collect();
313        Self { layers }
314    }
315
316    /// Create a chain of radial flows
317    pub fn radial(dim: usize, n_layers: usize, seed: u64) -> Self {
318        let layers = (0..n_layers)
319            .map(|i| FlowLayer::new_radial(dim, seed + i as u64 * 11))
320            .collect();
321        Self { layers }
322    }
323
324    /// Create a mixed chain alternating planar and radial flows
325    pub fn mixed(dim: usize, n_layers: usize, seed: u64) -> Self {
326        let layers = (0..n_layers)
327            .map(|i| {
328                if i % 2 == 0 {
329                    FlowLayer::new_planar(dim, seed + i as u64 * 13)
330                } else {
331                    FlowLayer::new_radial(dim, seed + i as u64 * 17)
332                }
333            })
334            .collect();
335        Self { layers }
336    }
337
338    /// Apply the full chain: z_K = f_K . ... . f_1(z_0)
339    ///
340    /// Returns (z_K, sum of log|det J_k|)
341    pub fn forward(&self, z0: &Array1<f64>) -> StatsResult<(Array1<f64>, f64)> {
342        let mut z = z0.clone();
343        let mut total_log_det = 0.0;
344
345        for layer in &self.layers {
346            let (z_new, log_det) = layer.forward(&z)?;
347            z = z_new;
348            total_log_det += log_det;
349        }
350
351        Ok((z, total_log_det))
352    }
353
354    /// Total number of flow parameters across all layers
355    pub fn n_params(&self) -> usize {
356        self.layers.iter().map(|l| l.n_params()).sum()
357    }
358
359    /// Get all flow parameters as a flat vector
360    pub fn get_params(&self) -> Array1<f64> {
361        let total = self.n_params();
362        let mut params = Array1::zeros(total);
363        let mut offset = 0;
364        for layer in &self.layers {
365            let lp = layer.get_params();
366            let n = lp.len();
367            for i in 0..n {
368                params[offset + i] = lp[i];
369            }
370            offset += n;
371        }
372        params
373    }
374
375    /// Set all flow parameters from a flat vector
376    pub fn set_params(&mut self, params: &Array1<f64>) -> StatsResult<()> {
377        let total = self.n_params();
378        if params.len() != total {
379            return Err(StatsError::DimensionMismatch(format!(
380                "Expected {} total flow params, got {}",
381                total,
382                params.len()
383            )));
384        }
385        let mut offset = 0;
386        for layer in &mut self.layers {
387            let n = layer.n_params();
388            let lp = Array1::from_shape_fn(n, |i| params[offset + i]);
389            layer.set_params(&lp)?;
390            offset += n;
391        }
392        Ok(())
393    }
394}
395
396// ============================================================================
397// Flow-enhanced Variational Inference
398// ============================================================================
399
400/// Configuration for flow-enhanced variational inference
401#[derive(Debug, Clone)]
402pub struct FlowViConfig {
403    /// Type of flow layers to use
404    pub flow_type: FlowType,
405    /// Number of flow layers
406    pub n_flow_layers: usize,
407    /// Number of MC samples for ELBO estimation
408    pub num_samples: usize,
409    /// Learning rate
410    pub learning_rate: f64,
411    /// Maximum iterations
412    pub max_iterations: usize,
413    /// Convergence tolerance
414    pub tolerance: f64,
415    /// Random seed
416    pub seed: u64,
417    /// Convergence window
418    pub convergence_window: usize,
419}
420
/// Which kind of layer the flow posterior is built from.
///
/// Derives `PartialEq`, `Eq` and `Hash` so configurations can be compared and
/// used as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FlowType {
    /// Every layer is a planar flow.
    Planar,
    /// Every layer is a radial flow.
    Radial,
    /// Planar and radial layers alternate, starting with planar.
    Mixed,
}
431
432impl Default for FlowViConfig {
433    fn default() -> Self {
434        Self {
435            flow_type: FlowType::Planar,
436            n_flow_layers: 4,
437            num_samples: 10,
438            learning_rate: 0.01,
439            max_iterations: 5000,
440            tolerance: 1e-4,
441            seed: 42,
442            convergence_window: 50,
443        }
444    }
445}
446
447/// Flow-enhanced Variational Inference
448///
449/// Uses a normalizing flow on top of a mean-field Gaussian base distribution
450/// to produce a more flexible posterior approximation.
451///
452/// The ELBO becomes:
453/// ```text
454/// ELBO = E_{z_0 ~ q_0} [log p(x, z_K) - log q_0(z_0) + sum_k log|det J_k|]
455/// ```
456/// where z_K = f_K . ... . f_1(z_0) and q_0 = N(mu, diag(sigma^2)).
457#[derive(Debug, Clone)]
458pub struct FlowVi {
459    /// Configuration
460    pub config: FlowViConfig,
461}
462
463impl FlowVi {
464    /// Create a new flow-enhanced VI instance
465    pub fn new(config: FlowViConfig) -> Self {
466        Self { config }
467    }
468
469    /// Generate quasi-random standard normal samples
470    fn generate_epsilon(&self, dim: usize, seed: u64) -> Array1<f64> {
471        let golden = 1.618033988749895_f64;
472        let plastic = 1.324717957244746_f64;
473        Array1::from_shape_fn(dim, |i| {
474            let u1 = ((seed as f64 * golden + i as f64 * plastic) % 1.0)
475                .abs()
476                .max(1e-10)
477                .min(1.0 - 1e-10);
478            let u2 = ((seed as f64 * plastic + i as f64 * golden) % 1.0)
479                .abs()
480                .max(1e-10)
481                .min(1.0 - 1e-10);
482            let r = (-2.0 * u1.ln()).sqrt();
483            r * (2.0 * PI * u2).cos()
484        })
485    }
486}
487
488/// Adam state for flow VI (for all parameters: base + flow)
489#[derive(Debug, Clone)]
490struct FlowAdamState {
491    m: Array1<f64>,
492    v: Array1<f64>,
493    t: usize,
494    beta1: f64,
495    beta2: f64,
496    epsilon: f64,
497}
498
499impl FlowAdamState {
500    fn new(n: usize) -> Self {
501        Self {
502            m: Array1::zeros(n),
503            v: Array1::zeros(n),
504            t: 0,
505            beta1: 0.9,
506            beta2: 0.999,
507            epsilon: 1e-8,
508        }
509    }
510
511    fn update(&mut self, grad: &Array1<f64>) -> Array1<f64> {
512        self.t += 1;
513        let n = grad.len();
514        let mut dir = Array1::zeros(n);
515        for i in 0..n {
516            self.m[i] = self.beta1 * self.m[i] + (1.0 - self.beta1) * grad[i];
517            self.v[i] = self.beta2 * self.v[i] + (1.0 - self.beta2) * grad[i] * grad[i];
518            let m_hat = self.m[i] / (1.0 - self.beta1.powi(self.t as i32));
519            let v_hat = self.v[i] / (1.0 - self.beta2.powi(self.t as i32));
520            dir[i] = m_hat / (v_hat.sqrt() + self.epsilon);
521        }
522        dir
523    }
524}
525
526impl VariationalInference for FlowVi {
527    fn fit<F>(&mut self, log_joint: F, dim: usize) -> StatsResult<PosteriorResult>
528    where
529        F: Fn(&Array1<f64>) -> StatsResult<(f64, Array1<f64>)>,
530    {
531        if dim == 0 {
532            return Err(StatsError::InvalidArgument(
533                "Dimension must be at least 1".to_string(),
534            ));
535        }
536        if self.config.n_flow_layers == 0 {
537            return Err(StatsError::InvalidArgument(
538                "n_flow_layers must be at least 1".to_string(),
539            ));
540        }
541
542        // Initialize base distribution parameters: mu, log_sigma
543        let mut mu = Array1::zeros(dim);
544        let mut log_sigma = Array1::zeros(dim);
545
546        // Initialize flow chain
547        let mut flow = match self.config.flow_type {
548            FlowType::Planar => {
549                NormalizingFlowChain::planar(dim, self.config.n_flow_layers, self.config.seed)
550            }
551            FlowType::Radial => {
552                NormalizingFlowChain::radial(dim, self.config.n_flow_layers, self.config.seed)
553            }
554            FlowType::Mixed => {
555                NormalizingFlowChain::mixed(dim, self.config.n_flow_layers, self.config.seed)
556            }
557        };
558
559        // Total parameters: base (2*dim) + flow params
560        let n_base = 2 * dim;
561        let n_flow = flow.n_params();
562        let n_total = n_base + n_flow;
563        let fd_eps = 1e-4;
564
565        let mut adam = FlowAdamState::new(n_total);
566        let mut elbo_history = Vec::with_capacity(self.config.max_iterations);
567        let mut converged = false;
568
569        for iter in 0..self.config.max_iterations {
570            // Evaluate ELBO at current params
571            let elbo = self.estimate_elbo(&mu, &log_sigma, &flow, &log_joint, iter)?;
572            elbo_history.push(elbo);
573
574            // Compute numerical gradient for all parameters
575            let mut full_grad = Array1::zeros(n_total);
576
577            // Gradient w.r.t. mu
578            for i in 0..dim {
579                let orig = mu[i];
580                mu[i] = orig + fd_eps;
581                let elbo_plus = self.estimate_elbo(&mu, &log_sigma, &flow, &log_joint, iter)?;
582                mu[i] = orig - fd_eps;
583                let elbo_minus = self.estimate_elbo(&mu, &log_sigma, &flow, &log_joint, iter)?;
584                mu[i] = orig;
585                full_grad[i] = (elbo_plus - elbo_minus) / (2.0 * fd_eps);
586            }
587
588            // Gradient w.r.t. log_sigma
589            for i in 0..dim {
590                let orig = log_sigma[i];
591                log_sigma[i] = orig + fd_eps;
592                let elbo_plus = self.estimate_elbo(&mu, &log_sigma, &flow, &log_joint, iter)?;
593                log_sigma[i] = orig - fd_eps;
594                let elbo_minus = self.estimate_elbo(&mu, &log_sigma, &flow, &log_joint, iter)?;
595                log_sigma[i] = orig;
596                full_grad[dim + i] = (elbo_plus - elbo_minus) / (2.0 * fd_eps);
597            }
598
599            // Gradient w.r.t. flow parameters
600            let flow_params = flow.get_params();
601            for i in 0..n_flow {
602                let mut fp_plus = flow_params.clone();
603                fp_plus[i] += fd_eps;
604                flow.set_params(&fp_plus)?;
605                let elbo_plus = self.estimate_elbo(&mu, &log_sigma, &flow, &log_joint, iter)?;
606
607                let mut fp_minus = flow_params.clone();
608                fp_minus[i] -= fd_eps;
609                flow.set_params(&fp_minus)?;
610                let elbo_minus = self.estimate_elbo(&mu, &log_sigma, &flow, &log_joint, iter)?;
611
612                flow.set_params(&flow_params)?;
613                full_grad[n_base + i] = (elbo_plus - elbo_minus) / (2.0 * fd_eps);
614            }
615
616            // Adam update
617            let direction = adam.update(&full_grad);
618            let lr = self.config.learning_rate;
619
620            for i in 0..dim {
621                mu[i] += lr * direction[i];
622                log_sigma[i] += lr * direction[dim + i];
623                log_sigma[i] = log_sigma[i].max(-10.0).min(10.0);
624            }
625
626            // Update flow parameters
627            let mut new_flow_params = flow.get_params();
628            for i in 0..n_flow {
629                new_flow_params[i] += lr * direction[n_base + i];
630                new_flow_params[i] = new_flow_params[i].max(-5.0).min(5.0);
631            }
632            flow.set_params(&new_flow_params)?;
633
634            // Check convergence
635            if elbo_history.len() >= self.config.convergence_window {
636                let n = elbo_history.len();
637                let w = self.config.convergence_window;
638                let hw = w / 2;
639                let recent_avg: f64 = elbo_history[n - hw..n].iter().sum::<f64>() / hw as f64;
640                let earlier_avg: f64 = elbo_history[n - w..n - hw].iter().sum::<f64>() / hw as f64;
641                if (recent_avg - earlier_avg).abs() < self.config.tolerance {
642                    converged = true;
643                    break;
644                }
645            }
646        }
647
648        // Generate posterior samples through the flow
649        let n_posterior_samples = 100;
650        let mut samples = Vec::with_capacity(n_posterior_samples);
651        for s in 0..n_posterior_samples {
652            let seed = self.config.seed.wrapping_add(100000 + s as u64);
653            let epsilon = self.generate_epsilon(dim, seed);
654            let sigma = log_sigma.mapv(f64::exp);
655            let z0 = &mu + &(&sigma * &epsilon);
656            let (z_k, _) = flow.forward(&z0)?;
657            samples.push(z_k);
658        }
659
660        // Compute means and stds from samples
661        let mut mean = Array1::zeros(dim);
662        for s in &samples {
663            mean = &mean + s;
664        }
665        mean /= n_posterior_samples as f64;
666
667        let mut var = Array1::zeros(dim);
668        for s in &samples {
669            let diff = s - &mean;
670            var = &var + &(&diff * &diff);
671        }
672        var /= (n_posterior_samples - 1).max(1) as f64;
673        let std_devs = var.mapv(f64::sqrt);
674
675        let iterations = elbo_history.len();
676        Ok(PosteriorResult {
677            means: mean,
678            std_devs,
679            elbo_history: elbo_history.clone(),
680            iterations,
681            converged,
682            samples: Some(samples),
683        })
684    }
685}
686
687impl FlowVi {
688    /// Estimate ELBO using Monte Carlo samples through the flow
689    fn estimate_elbo<F>(
690        &self,
691        mu: &Array1<f64>,
692        log_sigma: &Array1<f64>,
693        flow: &NormalizingFlowChain,
694        log_joint: &F,
695        iter: usize,
696    ) -> StatsResult<f64>
697    where
698        F: Fn(&Array1<f64>) -> StatsResult<(f64, Array1<f64>)>,
699    {
700        let dim = mu.len();
701        let sigma = log_sigma.mapv(f64::exp);
702        let mut elbo_sum = 0.0;
703
704        for s in 0..self.config.num_samples {
705            let seed = self
706                .config
707                .seed
708                .wrapping_add(iter as u64 * 1000)
709                .wrapping_add(s as u64);
710            let epsilon = self.generate_epsilon(dim, seed);
711
712            // Sample from base: z_0 = mu + sigma * epsilon
713            let z0 = mu + &(&sigma * &epsilon);
714
715            // Apply flow chain
716            let (z_k, sum_log_det) = flow.forward(&z0)?;
717
718            // log q_0(z_0) = sum_i [-0.5*log(2pi) - log(sigma_i) - 0.5*(epsilon_i)^2]
719            let log_q0: f64 = (0..dim)
720                .map(|i| -0.5 * (2.0 * PI).ln() - log_sigma[i] - 0.5 * epsilon[i] * epsilon[i])
721                .sum();
722
723            // log q_K(z_K) = log q_0(z_0) - sum_k log|det J_k|
724            let _log_q_k = log_q0 - sum_log_det;
725
726            // ELBO = E[log p(x, z_K) - log q_K(z_K)]
727            //      = E[log p(x, z_K) - log q_0(z_0) + sum log|det J_k|]
728            let (log_p, _) = log_joint(&z_k)?;
729            let elbo_s = log_p - log_q0 + sum_log_det;
730            elbo_sum += elbo_s;
731        }
732
733        Ok(elbo_sum / self.config.num_samples as f64)
734    }
735}
736
737// ============================================================================
738// Tests
739// ============================================================================
740
#[cfg(test)]
mod tests {
    use super::*;

    /// Central-difference estimate of `ln|det J|` for a 2-D `layer` at `z`.
    fn numerical_log_abs_det_2d(layer: &FlowLayer, z: &Array1<f64>) -> f64 {
        let h = 1e-6;
        let mut jac = [[0.0_f64; 2]; 2];
        for col in 0..2 {
            let mut z_plus = z.clone();
            z_plus[col] += h;
            let mut z_minus = z.clone();
            z_minus[col] -= h;
            let (f_plus, _) = layer.forward(&z_plus).expect("forward should succeed");
            let (f_minus, _) = layer.forward(&z_minus).expect("forward should succeed");
            for (row, jac_row) in jac.iter_mut().enumerate() {
                jac_row[col] = (f_plus[row] - f_minus[row]) / (2.0 * h);
            }
        }
        (jac[0][0] * jac[1][1] - jac[0][1] * jac[1][0]).abs().ln()
    }

    /// Test: planar flow yields a finite, non-zero Jacobian determinant
    /// (planar flows are not volume-preserving; this only checks det(J) != 0)
    #[test]
    fn test_planar_flow_volume_preservation() {
        let layer = FlowLayer::new_planar(3, 42);
        let z = Array1::from_vec(vec![1.0, -0.5, 0.3]);
        let (fz, log_det) = layer.forward(&z).expect("forward should succeed");

        assert_eq!(fz.len(), 3, "Output dimension should match input");
        assert!(
            log_det.is_finite(),
            "Log-det-Jacobian should be finite, got {}",
            log_det
        );
        // The Jacobian determinant should never be exactly zero
        // (enforced by invertibility constraint)
        assert!(
            log_det.exp() > 1e-15,
            "det(J) should be nonzero, got exp({}) = {}",
            log_det,
            log_det.exp()
        );
    }

    /// Test: radial flow yields a finite, non-zero Jacobian determinant
    #[test]
    fn test_radial_flow_volume_preservation() {
        let layer = FlowLayer::new_radial(3, 42);
        let z = Array1::from_vec(vec![1.0, -0.5, 0.3]);
        let (fz, log_det) = layer.forward(&z).expect("forward should succeed");

        assert_eq!(fz.len(), 3);
        assert!(log_det.is_finite(), "Log-det should be finite");
        assert!(log_det.exp() > 1e-15, "det(J) should be nonzero");
    }

    /// Test: planar analytic log-det agrees with a finite-difference Jacobian
    #[test]
    fn test_planar_log_det_matches_numerical_jacobian() {
        let layer = FlowLayer::Planar {
            w: Array1::from_vec(vec![0.8, -0.4]),
            u: Array1::from_vec(vec![0.5, 0.3]),
            b: 0.2,
        };
        let z = Array1::from_vec(vec![0.3, -0.7]);
        let (_, log_det) = layer.forward(&z).expect("forward should succeed");
        let numeric = numerical_log_abs_det_2d(&layer, &z);
        assert!(
            (log_det - numeric).abs() < 1e-6,
            "analytic {} vs numeric {}",
            log_det,
            numeric
        );
    }

    /// Test: radial analytic log-det agrees with a finite-difference Jacobian
    /// for both an expanding and a contracting beta
    #[test]
    fn test_radial_log_det_matches_numerical_jacobian() {
        for &beta in &[0.4, -2.0] {
            let layer = FlowLayer::Radial {
                z0: Array1::from_vec(vec![0.1, -0.2]),
                alpha: 0.7,
                beta,
            };
            let z = Array1::from_vec(vec![0.9, 0.5]);
            let (_, log_det) = layer.forward(&z).expect("forward should succeed");
            let numeric = numerical_log_abs_det_2d(&layer, &z);
            assert!(
                (log_det - numeric).abs() < 1e-6,
                "beta={}: analytic {} vs numeric {}",
                beta,
                log_det,
                numeric
            );
        }
    }

    /// Test: flow chain application and log-det accumulation
    #[test]
    fn test_flow_chain_forward() {
        let flow = NormalizingFlowChain::planar(2, 4, 42);
        let z0 = Array1::from_vec(vec![0.5, -0.3]);
        let (z_k, total_log_det) = flow.forward(&z0).expect("chain forward should succeed");

        assert_eq!(z_k.len(), 2);
        assert!(total_log_det.is_finite(), "Total log-det should be finite");

        // Compare with applying layers individually
        let mut z = z0.clone();
        let mut accum = 0.0;
        for layer in &flow.layers {
            let (z_new, ld) = layer.forward(&z).expect("layer forward should succeed");
            z = z_new;
            accum += ld;
        }
        assert!(
            (total_log_det - accum).abs() < 1e-10,
            "Chain log-det ({}) should equal accumulated ({})",
            total_log_det,
            accum
        );
    }

    /// Test: flow parameters roundtrip (get/set)
    #[test]
    fn test_flow_params_roundtrip() {
        let mut flow = NormalizingFlowChain::mixed(3, 4, 42);
        let params = flow.get_params();
        let n = params.len();
        assert!(n > 0, "Should have flow parameters");

        // Perturb and restore
        let mut perturbed = params.clone();
        for i in 0..n {
            perturbed[i] += 0.1;
        }
        flow.set_params(&perturbed).expect("set should succeed");
        let retrieved = flow.get_params();
        for i in 0..n {
            assert!(
                (retrieved[i] - perturbed[i]).abs() < 1e-10,
                "Param {} mismatch after set",
                i
            );
        }

        // Restore original
        flow.set_params(&params).expect("restore should succeed");
        let restored = flow.get_params();
        for i in 0..n {
            assert!(
                (restored[i] - params[i]).abs() < 1e-10,
                "Param {} mismatch after restore",
                i
            );
        }
    }

    /// Test: FlowVI on a 1-D N(2, 1) target yields a finite final ELBO and a
    /// posterior mean near 2.0 (no comparison against a flow-free baseline is made)
    #[test]
    fn test_flow_vi_improves_elbo() {
        // Target: N(2, 1)
        let target_fn = |theta: &Array1<f64>| -> StatsResult<(f64, Array1<f64>)> {
            let x = theta[0];
            let log_p = -0.5 * (x - 2.0).powi(2);
            let grad = Array1::from_vec(vec![-(x - 2.0)]);
            Ok((log_p, grad))
        };

        let flow_config = FlowViConfig {
            flow_type: FlowType::Planar,
            n_flow_layers: 2,
            num_samples: 10,
            learning_rate: 0.01,
            max_iterations: 200,
            tolerance: 1e-6,
            seed: 42,
            convergence_window: 50,
        };

        let mut flow_vi = FlowVi::new(flow_config);
        let result = flow_vi.fit(target_fn, 1).expect("FlowVI should succeed");

        // Basic sanity checks
        assert!(!result.elbo_history.is_empty(), "Should have ELBO history");
        let final_elbo = result
            .elbo_history
            .last()
            .copied()
            .unwrap_or(f64::NEG_INFINITY);
        assert!(
            final_elbo.is_finite(),
            "Final ELBO should be finite, got {}",
            final_elbo
        );

        // Mean should be near 2.0
        assert!(
            (result.means[0] - 2.0).abs() < 2.0,
            "Mean should be near 2.0, got {}",
            result.means[0]
        );
    }

    /// Test: dimension mismatch error
    #[test]
    fn test_flow_dimension_mismatch() {
        let layer = FlowLayer::Planar {
            w: Array1::from_vec(vec![1.0, 0.5]),
            u: Array1::from_vec(vec![0.3, -0.2]),
            b: 0.1,
        };
        let z = Array1::from_vec(vec![1.0, 2.0, 3.0]); // wrong dim
        let result = layer.forward(&z);
        assert!(result.is_err(), "Should fail on dimension mismatch");
    }

    /// Test: zero dimension error for FlowVi
    #[test]
    fn test_flow_vi_zero_dim() {
        let mut fv = FlowVi::new(FlowViConfig::default());
        let result = fv.fit(|_: &Array1<f64>| Ok((0.0, Array1::zeros(0))), 0);
        assert!(result.is_err());
    }

    /// Test: softplus basic values and extreme inputs
    #[test]
    fn test_softplus_values() {
        assert!((softplus(0.0) - std::f64::consts::LN_2).abs() < 1e-12);
        assert!(softplus(-1000.0) >= 0.0);
        assert!(softplus(1000.0).is_finite());
        assert!((softplus(1000.0) - 1000.0).abs() < 1e-9);
    }

    /// Test: planar invertibility enforcement
    #[test]
    fn test_planar_invertibility() {
        // Even with adversarial w, u, the flow should produce finite outputs
        let w = Array1::from_vec(vec![1.0, 0.0]);
        let u = Array1::from_vec(vec![-5.0, 0.0]); // w^T u = -5 < -1, needs correction
        let u_hat = enforce_planar_invertibility(&w, &u);
        let wtu_hat = w.dot(&u_hat);
        assert!(
            wtu_hat >= -1.0 - 1e-10,
            "w^T u_hat should be >= -1, got {}",
            wtu_hat
        );
    }
}