1use crate::error::{StatsError, StatsResult};
15use scirs2_core::ndarray::Array1;
16use std::f64::consts::PI;
17
18use super::{PosteriorResult, VariationalInference};
19
/// One invertible transformation in a normalizing-flow chain.
///
/// Each variant stores its own trainable parameters; `FlowLayer::forward`
/// applies the transformation and returns its log-Jacobian determinant.
#[derive(Debug, Clone)]
pub enum FlowLayer {
    /// Planar flow `f(z) = z + u_hat * tanh(w . z + b)`, where `u_hat` is `u`
    /// adjusted by `enforce_planar_invertibility` so that `w . u_hat >= -1`.
    Planar {
        /// Direction vector inside the `tanh`; its length must equal `z.len()`.
        w: Array1<f64>,
        /// Raw output direction; projected to `u_hat` before use.
        u: Array1<f64>,
        /// Scalar bias inside the `tanh`.
        b: f64,
    },
    /// Radial flow `f(z) = z + beta_hat * (z - z0) / (alpha + r)` with
    /// `r = |z - z0|`. `forward` uses `max(|alpha|, 1e-6)` in place of `alpha`
    /// and reparameterises `beta` so that `beta_hat > -alpha`.
    Radial {
        /// Reference point the flow expands or contracts around.
        z0: Array1<f64>,
        /// Raw scale parameter (its absolute value, floored at 1e-6, is used).
        alpha: f64,
        /// Raw strength parameter, mapped through softplus to `beta_hat`.
        beta: f64,
    },
}
46
47impl FlowLayer {
48 pub fn new_planar(dim: usize, seed: u64) -> Self {
50 let golden = 1.618033988749895_f64;
51 let plastic = 1.324717957244746_f64;
52
53 let w = Array1::from_shape_fn(dim, |i| {
54 let u1 = ((seed as f64 * golden + i as f64 * plastic + 0.3) % 1.0)
55 .abs()
56 .max(1e-10)
57 .min(1.0 - 1e-10);
58 let u2 = ((seed as f64 * plastic + i as f64 * golden + 0.7) % 1.0)
59 .abs()
60 .max(1e-10)
61 .min(1.0 - 1e-10);
62 let r = (-2.0 * u1.ln()).sqrt();
63 r * (2.0 * PI * u2).cos() * 0.1
64 });
65
66 let u = Array1::from_shape_fn(dim, |i| {
67 let u1 = (((seed + 100) as f64 * golden + i as f64 * plastic + 0.1) % 1.0)
68 .abs()
69 .max(1e-10)
70 .min(1.0 - 1e-10);
71 let u2 = (((seed + 100) as f64 * plastic + i as f64 * golden + 0.9) % 1.0)
72 .abs()
73 .max(1e-10)
74 .min(1.0 - 1e-10);
75 let r = (-2.0 * u1.ln()).sqrt();
76 r * (2.0 * PI * u2).cos() * 0.1
77 });
78
79 let b_val = {
80 let u1 = ((seed as f64 * 0.37 + 0.5) % 1.0)
81 .abs()
82 .max(1e-10)
83 .min(1.0 - 1e-10);
84 let u2 = ((seed as f64 * 0.73 + 0.5) % 1.0)
85 .abs()
86 .max(1e-10)
87 .min(1.0 - 1e-10);
88 let r = (-2.0 * u1.ln()).sqrt();
89 r * (2.0 * PI * u2).cos() * 0.1
90 };
91
92 FlowLayer::Planar { w, u, b: b_val }
93 }
94
95 pub fn new_radial(dim: usize, seed: u64) -> Self {
97 let golden = 1.618033988749895_f64;
98 let plastic = 1.324717957244746_f64;
99
100 let z0 = Array1::from_shape_fn(dim, |i| {
101 let u1 = (((seed + 200) as f64 * golden + i as f64 * plastic + 0.2) % 1.0)
102 .abs()
103 .max(1e-10)
104 .min(1.0 - 1e-10);
105 let u2 = (((seed + 200) as f64 * plastic + i as f64 * golden + 0.8) % 1.0)
106 .abs()
107 .max(1e-10)
108 .min(1.0 - 1e-10);
109 let r = (-2.0 * u1.ln()).sqrt();
110 r * (2.0 * PI * u2).cos() * 0.1
111 });
112
113 FlowLayer::Radial {
114 z0,
115 alpha: 1.0,
116 beta: 0.1,
117 }
118 }
119
120 pub fn forward(&self, z: &Array1<f64>) -> StatsResult<(Array1<f64>, f64)> {
124 match self {
125 FlowLayer::Planar { w, u, b } => {
126 let dim = z.len();
127 if w.len() != dim || u.len() != dim {
128 return Err(StatsError::DimensionMismatch(format!(
129 "Flow dimension mismatch: z={}, w={}, u={}",
130 dim,
131 w.len(),
132 u.len()
133 )));
134 }
135
136 let u_hat = enforce_planar_invertibility(w, u);
139
140 let wtz = w.dot(z) + b;
141 let tanh_wtz = wtz.tanh();
142
143 let fz = z + &(&u_hat * tanh_wtz);
145
146 let dtanh = 1.0 - tanh_wtz * tanh_wtz;
148 let psi = w * dtanh;
149 let det_term = 1.0 + u_hat.dot(&psi);
150
151 let log_det = det_term.abs().max(1e-15).ln();
152
153 Ok((fz, log_det))
154 }
155 FlowLayer::Radial { z0, alpha, beta } => {
156 let dim = z.len();
157 if z0.len() != dim {
158 return Err(StatsError::DimensionMismatch(format!(
159 "Flow dimension mismatch: z={}, z0={}",
160 dim,
161 z0.len()
162 )));
163 }
164
165 let diff = z - z0;
166 let r = diff.dot(&diff).sqrt().max(1e-10);
167 let alpha_pos = alpha.abs().max(1e-6);
168
169 let beta_hat = -alpha_pos + softplus(*beta + alpha_pos);
171
172 let h = 1.0 / (alpha_pos + r);
173 let h_prime = -1.0 / ((alpha_pos + r) * (alpha_pos + r));
174
175 let fz = z + &(&diff * (beta_hat * h));
177
178 let d = dim as f64;
181 let term1 = 1.0 + beta_hat * h;
182 let term2 = 1.0 + beta_hat * h + beta_hat * h_prime * r;
183
184 let log_det = (d - 1.0) * term1.abs().max(1e-15).ln() + term2.abs().max(1e-15).ln();
185
186 Ok((fz, log_det))
187 }
188 }
189 }
190
191 pub fn n_params(&self) -> usize {
193 match self {
194 FlowLayer::Planar { w, u, .. } => w.len() + u.len() + 1,
195 FlowLayer::Radial { z0, .. } => z0.len() + 2,
196 }
197 }
198
199 pub fn get_params(&self) -> Array1<f64> {
201 match self {
202 FlowLayer::Planar { w, u, b } => {
203 let dim = w.len();
204 let mut params = Array1::zeros(2 * dim + 1);
205 for i in 0..dim {
206 params[i] = w[i];
207 params[dim + i] = u[i];
208 }
209 params[2 * dim] = *b;
210 params
211 }
212 FlowLayer::Radial { z0, alpha, beta } => {
213 let dim = z0.len();
214 let mut params = Array1::zeros(dim + 2);
215 for i in 0..dim {
216 params[i] = z0[i];
217 }
218 params[dim] = *alpha;
219 params[dim + 1] = *beta;
220 params
221 }
222 }
223 }
224
225 pub fn set_params(&mut self, params: &Array1<f64>) -> StatsResult<()> {
227 match self {
228 FlowLayer::Planar { w, u, b } => {
229 let dim = w.len();
230 if params.len() != 2 * dim + 1 {
231 return Err(StatsError::DimensionMismatch(format!(
232 "Expected {} params, got {}",
233 2 * dim + 1,
234 params.len()
235 )));
236 }
237 for i in 0..dim {
238 w[i] = params[i];
239 u[i] = params[dim + i];
240 }
241 *b = params[2 * dim];
242 Ok(())
243 }
244 FlowLayer::Radial { z0, alpha, beta } => {
245 let dim = z0.len();
246 if params.len() != dim + 2 {
247 return Err(StatsError::DimensionMismatch(format!(
248 "Expected {} params, got {}",
249 dim + 2,
250 params.len()
251 )));
252 }
253 for i in 0..dim {
254 z0[i] = params[i];
255 }
256 *alpha = params[dim];
257 *beta = params[dim + 1];
258 Ok(())
259 }
260 }
261 }
262}
263
264fn enforce_planar_invertibility(w: &Array1<f64>, u: &Array1<f64>) -> Array1<f64> {
267 let wtu = w.dot(u);
268 let w_norm_sq = w.dot(w);
269 if w_norm_sq < 1e-15 {
270 return u.clone();
271 }
272 let m_wtu = -1.0 + softplus(wtu);
274 if (m_wtu - wtu).abs() < 1e-15 {
275 return u.clone();
276 }
277 u + &(w * ((m_wtu - wtu) / w_norm_sq))
278}
279
/// Numerically stable softplus, `ln(1 + e^x)`.
///
/// Uses the identity `softplus(x) = max(x, 0) + ln(1 + e^{-|x|})`. This never
/// exponentiates a positive number, so it cannot overflow for large `x`, and it
/// evaluates the logarithm with `ln_1p`, so there is no cancellation when
/// `e^{-|x|}` is tiny. The result is non-negative, approaches `x` as
/// `x -> +inf`, and approaches `e^x` as `x -> -inf`. NaN input yields NaN.
fn softplus(x: f64) -> f64 {
    x.max(0.0) + (-x.abs()).exp().ln_1p()
}
290
/// An ordered sequence of [`FlowLayer`]s applied one after another.
///
/// `NormalizingFlowChain::forward` reports the sum of the per-layer
/// log-Jacobian determinants as the chain's total log-determinant.
#[derive(Debug, Clone)]
pub struct NormalizingFlowChain {
    /// Layers in application order (index 0 is applied first).
    pub layers: Vec<FlowLayer>,
}
301
302impl NormalizingFlowChain {
303 pub fn new(layers: Vec<FlowLayer>) -> Self {
305 Self { layers }
306 }
307
308 pub fn planar(dim: usize, n_layers: usize, seed: u64) -> Self {
310 let layers = (0..n_layers)
311 .map(|i| FlowLayer::new_planar(dim, seed + i as u64 * 7))
312 .collect();
313 Self { layers }
314 }
315
316 pub fn radial(dim: usize, n_layers: usize, seed: u64) -> Self {
318 let layers = (0..n_layers)
319 .map(|i| FlowLayer::new_radial(dim, seed + i as u64 * 11))
320 .collect();
321 Self { layers }
322 }
323
324 pub fn mixed(dim: usize, n_layers: usize, seed: u64) -> Self {
326 let layers = (0..n_layers)
327 .map(|i| {
328 if i % 2 == 0 {
329 FlowLayer::new_planar(dim, seed + i as u64 * 13)
330 } else {
331 FlowLayer::new_radial(dim, seed + i as u64 * 17)
332 }
333 })
334 .collect();
335 Self { layers }
336 }
337
338 pub fn forward(&self, z0: &Array1<f64>) -> StatsResult<(Array1<f64>, f64)> {
342 let mut z = z0.clone();
343 let mut total_log_det = 0.0;
344
345 for layer in &self.layers {
346 let (z_new, log_det) = layer.forward(&z)?;
347 z = z_new;
348 total_log_det += log_det;
349 }
350
351 Ok((z, total_log_det))
352 }
353
354 pub fn n_params(&self) -> usize {
356 self.layers.iter().map(|l| l.n_params()).sum()
357 }
358
359 pub fn get_params(&self) -> Array1<f64> {
361 let total = self.n_params();
362 let mut params = Array1::zeros(total);
363 let mut offset = 0;
364 for layer in &self.layers {
365 let lp = layer.get_params();
366 let n = lp.len();
367 for i in 0..n {
368 params[offset + i] = lp[i];
369 }
370 offset += n;
371 }
372 params
373 }
374
375 pub fn set_params(&mut self, params: &Array1<f64>) -> StatsResult<()> {
377 let total = self.n_params();
378 if params.len() != total {
379 return Err(StatsError::DimensionMismatch(format!(
380 "Expected {} total flow params, got {}",
381 total,
382 params.len()
383 )));
384 }
385 let mut offset = 0;
386 for layer in &mut self.layers {
387 let n = layer.n_params();
388 let lp = Array1::from_shape_fn(n, |i| params[offset + i]);
389 layer.set_params(&lp)?;
390 offset += n;
391 }
392 Ok(())
393 }
394}
395
/// Configuration for [`FlowVi`].
#[derive(Debug, Clone)]
pub struct FlowViConfig {
    /// Layer family used to build the flow chain.
    pub flow_type: FlowType,
    /// Number of flow layers; `fit` rejects 0.
    pub n_flow_layers: usize,
    /// Monte Carlo samples averaged per ELBO estimate.
    pub num_samples: usize,
    /// Step size multiplying the Adam update direction.
    pub learning_rate: f64,
    /// Upper bound on optimisation iterations.
    pub max_iterations: usize,
    /// Convergence threshold on the absolute difference between the mean ELBO
    /// of the two halves of the trailing `convergence_window` iterations.
    pub tolerance: f64,
    /// Base seed for the deterministic flow initialisation and noise sequences.
    pub seed: u64,
    /// Number of trailing ELBO values inspected by the convergence check.
    pub convergence_window: usize,
}
420
/// Which family of layers the flow chain in `FlowVi::fit` is built from.
///
/// Also derives `PartialEq`/`Eq`/`Hash` so configurations can be compared and
/// used as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FlowType {
    /// Every layer is a planar flow.
    Planar,
    /// Every layer is a radial flow.
    Radial,
    /// Alternating layers: planar at even indices, radial at odd indices.
    Mixed,
}
431
432impl Default for FlowViConfig {
433 fn default() -> Self {
434 Self {
435 flow_type: FlowType::Planar,
436 n_flow_layers: 4,
437 num_samples: 10,
438 learning_rate: 0.01,
439 max_iterations: 5000,
440 tolerance: 1e-4,
441 seed: 42,
442 convergence_window: 50,
443 }
444 }
445}
446
/// Variational inference with a diagonal-Gaussian base distribution pushed
/// through a normalizing-flow chain; run it via the `VariationalInference` trait.
#[derive(Debug, Clone)]
pub struct FlowVi {
    /// Optimisation and flow settings.
    pub config: FlowViConfig,
}
462
463impl FlowVi {
464 pub fn new(config: FlowViConfig) -> Self {
466 Self { config }
467 }
468
469 fn generate_epsilon(&self, dim: usize, seed: u64) -> Array1<f64> {
471 let golden = 1.618033988749895_f64;
472 let plastic = 1.324717957244746_f64;
473 Array1::from_shape_fn(dim, |i| {
474 let u1 = ((seed as f64 * golden + i as f64 * plastic) % 1.0)
475 .abs()
476 .max(1e-10)
477 .min(1.0 - 1e-10);
478 let u2 = ((seed as f64 * plastic + i as f64 * golden) % 1.0)
479 .abs()
480 .max(1e-10)
481 .min(1.0 - 1e-10);
482 let r = (-2.0 * u1.ln()).sqrt();
483 r * (2.0 * PI * u2).cos()
484 })
485 }
486}
487
/// Adam optimiser state for the flattened VI parameter vector.
#[derive(Debug, Clone)]
struct FlowAdamState {
    /// First-moment (running mean) estimate per parameter.
    m: Array1<f64>,
    /// Second-moment (running uncentred variance) estimate per parameter.
    v: Array1<f64>,
    /// Number of updates performed so far; drives the bias correction.
    t: usize,
    /// Exponential decay rate for `m`.
    beta1: f64,
    /// Exponential decay rate for `v`.
    beta2: f64,
    /// Constant added to the denominator to avoid division by zero.
    epsilon: f64,
}
498
499impl FlowAdamState {
500 fn new(n: usize) -> Self {
501 Self {
502 m: Array1::zeros(n),
503 v: Array1::zeros(n),
504 t: 0,
505 beta1: 0.9,
506 beta2: 0.999,
507 epsilon: 1e-8,
508 }
509 }
510
511 fn update(&mut self, grad: &Array1<f64>) -> Array1<f64> {
512 self.t += 1;
513 let n = grad.len();
514 let mut dir = Array1::zeros(n);
515 for i in 0..n {
516 self.m[i] = self.beta1 * self.m[i] + (1.0 - self.beta1) * grad[i];
517 self.v[i] = self.beta2 * self.v[i] + (1.0 - self.beta2) * grad[i] * grad[i];
518 let m_hat = self.m[i] / (1.0 - self.beta1.powi(self.t as i32));
519 let v_hat = self.v[i] / (1.0 - self.beta2.powi(self.t as i32));
520 dir[i] = m_hat / (v_hat.sqrt() + self.epsilon);
521 }
522 dir
523 }
524}
525
526impl VariationalInference for FlowVi {
527 fn fit<F>(&mut self, log_joint: F, dim: usize) -> StatsResult<PosteriorResult>
528 where
529 F: Fn(&Array1<f64>) -> StatsResult<(f64, Array1<f64>)>,
530 {
531 if dim == 0 {
532 return Err(StatsError::InvalidArgument(
533 "Dimension must be at least 1".to_string(),
534 ));
535 }
536 if self.config.n_flow_layers == 0 {
537 return Err(StatsError::InvalidArgument(
538 "n_flow_layers must be at least 1".to_string(),
539 ));
540 }
541
542 let mut mu = Array1::zeros(dim);
544 let mut log_sigma = Array1::zeros(dim);
545
546 let mut flow = match self.config.flow_type {
548 FlowType::Planar => {
549 NormalizingFlowChain::planar(dim, self.config.n_flow_layers, self.config.seed)
550 }
551 FlowType::Radial => {
552 NormalizingFlowChain::radial(dim, self.config.n_flow_layers, self.config.seed)
553 }
554 FlowType::Mixed => {
555 NormalizingFlowChain::mixed(dim, self.config.n_flow_layers, self.config.seed)
556 }
557 };
558
559 let n_base = 2 * dim;
561 let n_flow = flow.n_params();
562 let n_total = n_base + n_flow;
563 let fd_eps = 1e-4;
564
565 let mut adam = FlowAdamState::new(n_total);
566 let mut elbo_history = Vec::with_capacity(self.config.max_iterations);
567 let mut converged = false;
568
569 for iter in 0..self.config.max_iterations {
570 let elbo = self.estimate_elbo(&mu, &log_sigma, &flow, &log_joint, iter)?;
572 elbo_history.push(elbo);
573
574 let mut full_grad = Array1::zeros(n_total);
576
577 for i in 0..dim {
579 let orig = mu[i];
580 mu[i] = orig + fd_eps;
581 let elbo_plus = self.estimate_elbo(&mu, &log_sigma, &flow, &log_joint, iter)?;
582 mu[i] = orig - fd_eps;
583 let elbo_minus = self.estimate_elbo(&mu, &log_sigma, &flow, &log_joint, iter)?;
584 mu[i] = orig;
585 full_grad[i] = (elbo_plus - elbo_minus) / (2.0 * fd_eps);
586 }
587
588 for i in 0..dim {
590 let orig = log_sigma[i];
591 log_sigma[i] = orig + fd_eps;
592 let elbo_plus = self.estimate_elbo(&mu, &log_sigma, &flow, &log_joint, iter)?;
593 log_sigma[i] = orig - fd_eps;
594 let elbo_minus = self.estimate_elbo(&mu, &log_sigma, &flow, &log_joint, iter)?;
595 log_sigma[i] = orig;
596 full_grad[dim + i] = (elbo_plus - elbo_minus) / (2.0 * fd_eps);
597 }
598
599 let flow_params = flow.get_params();
601 for i in 0..n_flow {
602 let mut fp_plus = flow_params.clone();
603 fp_plus[i] += fd_eps;
604 flow.set_params(&fp_plus)?;
605 let elbo_plus = self.estimate_elbo(&mu, &log_sigma, &flow, &log_joint, iter)?;
606
607 let mut fp_minus = flow_params.clone();
608 fp_minus[i] -= fd_eps;
609 flow.set_params(&fp_minus)?;
610 let elbo_minus = self.estimate_elbo(&mu, &log_sigma, &flow, &log_joint, iter)?;
611
612 flow.set_params(&flow_params)?;
613 full_grad[n_base + i] = (elbo_plus - elbo_minus) / (2.0 * fd_eps);
614 }
615
616 let direction = adam.update(&full_grad);
618 let lr = self.config.learning_rate;
619
620 for i in 0..dim {
621 mu[i] += lr * direction[i];
622 log_sigma[i] += lr * direction[dim + i];
623 log_sigma[i] = log_sigma[i].max(-10.0).min(10.0);
624 }
625
626 let mut new_flow_params = flow.get_params();
628 for i in 0..n_flow {
629 new_flow_params[i] += lr * direction[n_base + i];
630 new_flow_params[i] = new_flow_params[i].max(-5.0).min(5.0);
631 }
632 flow.set_params(&new_flow_params)?;
633
634 if elbo_history.len() >= self.config.convergence_window {
636 let n = elbo_history.len();
637 let w = self.config.convergence_window;
638 let hw = w / 2;
639 let recent_avg: f64 = elbo_history[n - hw..n].iter().sum::<f64>() / hw as f64;
640 let earlier_avg: f64 = elbo_history[n - w..n - hw].iter().sum::<f64>() / hw as f64;
641 if (recent_avg - earlier_avg).abs() < self.config.tolerance {
642 converged = true;
643 break;
644 }
645 }
646 }
647
648 let n_posterior_samples = 100;
650 let mut samples = Vec::with_capacity(n_posterior_samples);
651 for s in 0..n_posterior_samples {
652 let seed = self.config.seed.wrapping_add(100000 + s as u64);
653 let epsilon = self.generate_epsilon(dim, seed);
654 let sigma = log_sigma.mapv(f64::exp);
655 let z0 = &mu + &(&sigma * &epsilon);
656 let (z_k, _) = flow.forward(&z0)?;
657 samples.push(z_k);
658 }
659
660 let mut mean = Array1::zeros(dim);
662 for s in &samples {
663 mean = &mean + s;
664 }
665 mean /= n_posterior_samples as f64;
666
667 let mut var = Array1::zeros(dim);
668 for s in &samples {
669 let diff = s - &mean;
670 var = &var + &(&diff * &diff);
671 }
672 var /= (n_posterior_samples - 1).max(1) as f64;
673 let std_devs = var.mapv(f64::sqrt);
674
675 let iterations = elbo_history.len();
676 Ok(PosteriorResult {
677 means: mean,
678 std_devs,
679 elbo_history: elbo_history.clone(),
680 iterations,
681 converged,
682 samples: Some(samples),
683 })
684 }
685}
686
687impl FlowVi {
688 fn estimate_elbo<F>(
690 &self,
691 mu: &Array1<f64>,
692 log_sigma: &Array1<f64>,
693 flow: &NormalizingFlowChain,
694 log_joint: &F,
695 iter: usize,
696 ) -> StatsResult<f64>
697 where
698 F: Fn(&Array1<f64>) -> StatsResult<(f64, Array1<f64>)>,
699 {
700 let dim = mu.len();
701 let sigma = log_sigma.mapv(f64::exp);
702 let mut elbo_sum = 0.0;
703
704 for s in 0..self.config.num_samples {
705 let seed = self
706 .config
707 .seed
708 .wrapping_add(iter as u64 * 1000)
709 .wrapping_add(s as u64);
710 let epsilon = self.generate_epsilon(dim, seed);
711
712 let z0 = mu + &(&sigma * &epsilon);
714
715 let (z_k, sum_log_det) = flow.forward(&z0)?;
717
718 let log_q0: f64 = (0..dim)
720 .map(|i| -0.5 * (2.0 * PI).ln() - log_sigma[i] - 0.5 * epsilon[i] * epsilon[i])
721 .sum();
722
723 let _log_q_k = log_q0 - sum_log_det;
725
726 let (log_p, _) = log_joint(&z_k)?;
729 let elbo_s = log_p - log_q0 + sum_log_det;
730 elbo_sum += elbo_s;
731 }
732
733 Ok(elbo_sum / self.config.num_samples as f64)
734 }
735}
736
#[cfg(test)]
mod tests {
    use super::*;

    /// A planar layer keeps the dimension and has a finite, nonzero Jacobian
    /// determinant (planar flows are not volume-preserving in general).
    #[test]
    fn test_planar_flow_volume_preservation() {
        let layer = FlowLayer::new_planar(3, 42);
        let z = Array1::from_vec(vec![1.0, -0.5, 0.3]);
        let (fz, log_det) = layer.forward(&z).expect("forward should succeed");

        assert_eq!(fz.len(), 3, "Output dimension should match input");
        assert!(
            log_det.is_finite(),
            "Log-det-Jacobian should be finite, got {}",
            log_det
        );
        assert!(
            log_det.exp() > 1e-15,
            "det(J) should be nonzero, got exp({}) = {}",
            log_det,
            log_det.exp()
        );
    }

    /// A radial layer keeps the dimension and has a finite, nonzero Jacobian determinant.
    #[test]
    fn test_radial_flow_volume_preservation() {
        let layer = FlowLayer::new_radial(3, 42);
        let z = Array1::from_vec(vec![1.0, -0.5, 0.3]);
        let (fz, log_det) = layer.forward(&z).expect("forward should succeed");

        assert_eq!(fz.len(), 3);
        assert!(log_det.is_finite(), "Log-det should be finite");
        assert!(log_det.exp() > 1e-15, "det(J) should be nonzero");
    }

    /// The chain's total log-det equals the sum of per-layer log-dets.
    #[test]
    fn test_flow_chain_forward() {
        let flow = NormalizingFlowChain::planar(2, 4, 42);
        let z0 = Array1::from_vec(vec![0.5, -0.3]);
        let (z_k, total_log_det) = flow.forward(&z0).expect("chain forward should succeed");

        assert_eq!(z_k.len(), 2);
        assert!(total_log_det.is_finite(), "Total log-det should be finite");

        let mut z = z0.clone();
        let mut accum = 0.0;
        for layer in &flow.layers {
            let (z_new, ld) = layer.forward(&z).expect("layer forward should succeed");
            z = z_new;
            accum += ld;
        }
        assert!(
            (total_log_det - accum).abs() < 1e-10,
            "Chain log-det ({}) should equal accumulated ({})",
            total_log_det,
            accum
        );
    }

    /// `set_params` followed by `get_params` round-trips exactly, and restoring works.
    #[test]
    fn test_flow_params_roundtrip() {
        let mut flow = NormalizingFlowChain::mixed(3, 4, 42);
        let params = flow.get_params();
        let n = params.len();
        assert!(n > 0, "Should have flow parameters");

        let mut perturbed = params.clone();
        for i in 0..n {
            perturbed[i] += 0.1;
        }
        flow.set_params(&perturbed).expect("set should succeed");
        let retrieved = flow.get_params();
        for i in 0..n {
            assert!(
                (retrieved[i] - perturbed[i]).abs() < 1e-10,
                "Param {} mismatch after set",
                i
            );
        }

        flow.set_params(&params).expect("restore should succeed");
        let restored = flow.get_params();
        for i in 0..n {
            assert!(
                (restored[i] - params[i]).abs() < 1e-10,
                "Param {} mismatch after restore",
                i
            );
        }
    }

    /// The flattened parameter vector length agrees with `n_params`.
    #[test]
    fn test_flow_chain_param_count() {
        let flow = NormalizingFlowChain::mixed(3, 3, 7);
        assert_eq!(flow.get_params().len(), flow.n_params());
    }

    /// Fitting a 1-D Gaussian target centred at 2 yields a finite ELBO and a
    /// posterior mean in the right neighbourhood.
    #[test]
    fn test_flow_vi_improves_elbo() {
        let target_fn = |theta: &Array1<f64>| -> StatsResult<(f64, Array1<f64>)> {
            let x = theta[0];
            let log_p = -0.5 * (x - 2.0).powi(2);
            let grad = Array1::from_vec(vec![-(x - 2.0)]);
            Ok((log_p, grad))
        };

        let flow_config = FlowViConfig {
            flow_type: FlowType::Planar,
            n_flow_layers: 2,
            num_samples: 10,
            learning_rate: 0.01,
            max_iterations: 200,
            tolerance: 1e-6,
            seed: 42,
            convergence_window: 50,
        };

        let mut flow_vi = FlowVi::new(flow_config);
        let result = flow_vi.fit(target_fn, 1).expect("FlowVI should succeed");

        assert!(!result.elbo_history.is_empty(), "Should have ELBO history");
        let final_elbo = result
            .elbo_history
            .last()
            .copied()
            .unwrap_or(f64::NEG_INFINITY);
        assert!(
            final_elbo.is_finite(),
            "Final ELBO should be finite, got {}",
            final_elbo
        );

        assert!(
            (result.means[0] - 2.0).abs() < 2.0,
            "Mean should be near 2.0, got {}",
            result.means[0]
        );
    }

    /// A layer rejects input whose length differs from its parameters.
    #[test]
    fn test_flow_dimension_mismatch() {
        let layer = FlowLayer::Planar {
            w: Array1::from_vec(vec![1.0, 0.5]),
            u: Array1::from_vec(vec![0.3, -0.2]),
            b: 0.1,
        };
        let z = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let result = layer.forward(&z);
        assert!(result.is_err(), "Should fail on dimension mismatch");
    }

    /// Zero-dimensional problems are rejected.
    #[test]
    fn test_flow_vi_zero_dim() {
        let mut fv = FlowVi::new(FlowViConfig::default());
        let result = fv.fit(|_: &Array1<f64>| Ok((0.0, Array1::zeros(0))), 0);
        assert!(result.is_err());
    }

    /// The planar projection guarantees `w . u_hat >= -1` even for strongly
    /// anti-aligned `u`.
    #[test]
    fn test_planar_invertibility() {
        let w = Array1::from_vec(vec![1.0, 0.0]);
        let u = Array1::from_vec(vec![-5.0, 0.0]);
        let u_hat = enforce_planar_invertibility(&w, &u);
        let wtu_hat = w.dot(&u_hat);
        assert!(
            wtu_hat >= -1.0 - 1e-10,
            "w^T u_hat should be >= -1, got {}",
            wtu_hat
        );
    }

    /// softplus matches ln(1 + e^x) at 0 and its asymptotes at large |x|.
    #[test]
    fn test_softplus_reference_values() {
        assert!((softplus(0.0) - std::f64::consts::LN_2).abs() < 1e-12);
        assert!((softplus(30.0) - 30.0).abs() < 1e-9);
        let tiny = softplus(-30.0);
        assert!(tiny >= 0.0 && tiny < 1e-12);
    }
}