scirs2_stats/variational/
svgd.rs1use crate::error::{StatsError, StatsResult};
20use scirs2_core::ndarray::Array1;
21use std::f64::consts::PI;
22
23use super::{PosteriorResult, VariationalInference};
24
25#[derive(Debug, Clone)]
31pub struct RbfKernel {
32 pub bandwidth: Option<f64>,
34}
35
36impl RbfKernel {
37 fn median_bandwidth(particles: &[Array1<f64>]) -> f64 {
40 let n = particles.len();
41 if n <= 1 {
42 return 1.0;
43 }
44
45 let mut dists_sq: Vec<f64> = Vec::with_capacity(n * (n - 1) / 2);
46 for i in 0..n {
47 for j in (i + 1)..n {
48 let diff = &particles[i] - &particles[j];
49 dists_sq.push(diff.dot(&diff));
50 }
51 }
52
53 if dists_sq.is_empty() {
54 return 1.0;
55 }
56
57 dists_sq.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
58 let median_sq = dists_sq[dists_sq.len() / 2];
59
60 let log_n = (n as f64).ln().max(1.0);
62 let h_sq = median_sq / log_n;
63 h_sq.max(1e-6).sqrt()
64 }
65
66 fn eval_with_grad(&self, x: &Array1<f64>, y: &Array1<f64>, h: f64) -> (f64, Array1<f64>) {
72 let diff = x - y;
73 let dist_sq = diff.dot(&diff);
74 let h_sq = h * h;
75 let k_val = (-dist_sq / (2.0 * h_sq)).exp();
76 let grad_x = &diff * (-k_val / h_sq);
77 (k_val, grad_x)
78 }
79}
80
81#[derive(Debug, Clone)]
86struct SvgdAdamState {
87 m: Vec<Array1<f64>>,
88 v: Vec<Array1<f64>>,
89 t: usize,
90 beta1: f64,
91 beta2: f64,
92 epsilon: f64,
93}
94
95impl SvgdAdamState {
96 fn new(n_particles: usize, dim: usize) -> Self {
97 Self {
98 m: vec![Array1::zeros(dim); n_particles],
99 v: vec![Array1::zeros(dim); n_particles],
100 t: 0,
101 beta1: 0.9,
102 beta2: 0.999,
103 epsilon: 1e-8,
104 }
105 }
106
107 fn update(&mut self, grads: &[Array1<f64>]) -> Vec<Array1<f64>> {
109 self.t += 1;
110 let n = grads.len();
111 let mut directions = Vec::with_capacity(n);
112
113 for i in 0..n {
114 let dim = grads[i].len();
115 let mut dir = Array1::zeros(dim);
116 for j in 0..dim {
117 self.m[i][j] = self.beta1 * self.m[i][j] + (1.0 - self.beta1) * grads[i][j];
118 self.v[i][j] =
119 self.beta2 * self.v[i][j] + (1.0 - self.beta2) * grads[i][j] * grads[i][j];
120 let m_hat = self.m[i][j] / (1.0 - self.beta1.powi(self.t as i32));
121 let v_hat = self.v[i][j] / (1.0 - self.beta2.powi(self.t as i32));
122 dir[j] = m_hat / (v_hat.sqrt() + self.epsilon);
123 }
124 directions.push(dir);
125 }
126
127 directions
128 }
129}
130
/// Tuning parameters for Stein variational gradient descent.
#[derive(Debug, Clone)]
pub struct SvgdConfig {
    /// Number of particles used to represent the posterior.
    pub num_particles: usize,
    /// Multiplier applied to every update direction.
    pub step_size: f64,
    /// Upper bound on the number of update iterations.
    pub max_iterations: usize,
    /// Threshold on the average update norm below which the run stops.
    pub tolerance: f64,
    /// Fixed kernel bandwidth; `None` selects the median heuristic.
    pub kernel_bandwidth: Option<f64>,
    /// Seed for the deterministic particle initialisation.
    pub seed: u64,
    /// Scale factor applied to the initial particle coordinates.
    pub init_spread: f64,
    /// Whether update directions are rescaled with Adam.
    pub use_adam: bool,
}

impl Default for SvgdConfig {
    /// 100 particles, step size 0.1, at most 1000 iterations, tolerance 1e-4,
    /// median-heuristic bandwidth, seed 42, unit initial spread, Adam enabled.
    fn default() -> Self {
        let num_particles = 100;
        let step_size = 0.1;
        let max_iterations = 1000;
        let tolerance = 1e-4;
        let kernel_bandwidth = None;
        let seed = 42;
        let init_spread = 1.0;
        let use_adam = true;

        Self {
            num_particles,
            step_size,
            max_iterations,
            tolerance,
            kernel_bandwidth,
            seed,
            init_spread,
            use_adam,
        }
    }
}
170
171#[derive(Debug, Clone)]
195pub struct Svgd {
196 pub config: SvgdConfig,
198 kernel: RbfKernel,
200}
201
202impl Svgd {
203 pub fn new(config: SvgdConfig) -> Self {
205 let kernel = RbfKernel {
206 bandwidth: config.kernel_bandwidth,
207 };
208 Self { config, kernel }
209 }
210
211 fn init_particles(&self, dim: usize) -> Vec<Array1<f64>> {
213 let n = self.config.num_particles;
214 let golden = 1.618033988749895_f64;
215 let plastic = 1.324717957244746_f64;
216
217 (0..n)
218 .map(|i| {
219 Array1::from_shape_fn(dim, |d| {
220 let seed = self.config.seed.wrapping_add(i as u64 * 1000 + d as u64);
221 let u1 = ((seed as f64 * golden + d as f64 * plastic) % 1.0).abs();
222 let u2 = ((seed as f64 * plastic + d as f64 * golden + 0.5) % 1.0).abs();
223 let u1 = u1.max(1e-10).min(1.0 - 1e-10);
224 let u2 = u2.max(1e-10).min(1.0 - 1e-10);
225 let r = (-2.0 * u1.ln()).sqrt();
226 r * (2.0 * PI * u2).cos() * self.config.init_spread
227 })
228 })
229 .collect()
230 }
231
232 fn compute_phi_star<F>(
237 &self,
238 particles: &[Array1<f64>],
239 log_joint: &F,
240 bandwidth: f64,
241 ) -> StatsResult<Vec<Array1<f64>>>
242 where
243 F: Fn(&Array1<f64>) -> StatsResult<(f64, Array1<f64>)>,
244 {
245 let n = particles.len();
246 let dim = particles[0].len();
247
248 let mut grad_log_p: Vec<Array1<f64>> = Vec::with_capacity(n);
250 for particle in particles {
251 let (_log_p, grad) = log_joint(particle)?;
252 grad_log_p.push(grad);
253 }
254
255 let mut phi_star: Vec<Array1<f64>> = vec![Array1::zeros(dim); n];
257
258 for i in 0..n {
259 for j in 0..n {
260 let (k_val, grad_k_j) =
261 self.kernel
262 .eval_with_grad(&particles[j], &particles[i], bandwidth);
263
264 for d in 0..dim {
266 phi_star[i][d] += k_val * grad_log_p[j][d];
267 }
268
269 for d in 0..dim {
272 phi_star[i][d] += grad_k_j[d];
273 }
274 }
275
276 phi_star[i] /= n as f64;
278 }
279
280 Ok(phi_star)
281 }
282
283 fn estimate_elbo<F>(
286 &self,
287 particles: &[Array1<f64>],
288 log_joint: &F,
289 bandwidth: f64,
290 ) -> StatsResult<f64>
291 where
292 F: Fn(&Array1<f64>) -> StatsResult<(f64, Array1<f64>)>,
293 {
294 let n = particles.len();
295 let dim = particles[0].len();
296
297 let mut avg_log_p = 0.0;
299 for particle in particles {
300 let (log_p, _) = log_joint(particle)?;
301 avg_log_p += log_p;
302 }
303 avg_log_p /= n as f64;
304
305 let mut entropy_est = 0.0;
308 for i in 0..n {
309 let mut kde_sum = 0.0;
310 for j in 0..n {
311 let diff = &particles[i] - &particles[j];
312 let dist_sq = diff.dot(&diff);
313 kde_sum += (-dist_sq / (2.0 * bandwidth * bandwidth)).exp();
314 }
315 let norm_const = (2.0 * PI * bandwidth * bandwidth).powf(dim as f64 / 2.0);
316 let density = kde_sum / (n as f64 * norm_const);
317 if density > 1e-300 {
318 entropy_est -= density.ln();
319 }
320 }
321 entropy_est /= n as f64;
322
323 Ok(avg_log_p + entropy_est)
324 }
325}
326
327impl VariationalInference for Svgd {
328 fn fit<F>(&mut self, log_joint: F, dim: usize) -> StatsResult<PosteriorResult>
329 where
330 F: Fn(&Array1<f64>) -> StatsResult<(f64, Array1<f64>)>,
331 {
332 if dim == 0 {
333 return Err(StatsError::InvalidArgument(
334 "Dimension must be at least 1".to_string(),
335 ));
336 }
337 if self.config.num_particles < 2 {
338 return Err(StatsError::InvalidArgument(
339 "num_particles must be at least 2".to_string(),
340 ));
341 }
342 if self.config.step_size <= 0.0 {
343 return Err(StatsError::InvalidArgument(
344 "step_size must be positive".to_string(),
345 ));
346 }
347
348 let n = self.config.num_particles;
349 let mut particles = self.init_particles(dim);
350 let mut elbo_history = Vec::with_capacity(self.config.max_iterations);
351 let mut converged = false;
352
353 let mut adam = if self.config.use_adam {
354 Some(SvgdAdamState::new(n, dim))
355 } else {
356 None
357 };
358
359 for _iter in 0..self.config.max_iterations {
360 let bandwidth = self
362 .config
363 .kernel_bandwidth
364 .unwrap_or_else(|| RbfKernel::median_bandwidth(&particles));
365
366 let phi_star = self.compute_phi_star(&particles, &log_joint, bandwidth)?;
368
369 let updates: Vec<Array1<f64>> = if let Some(ref mut adam_state) = adam {
371 let directions = adam_state.update(&phi_star);
372 directions
373 .into_iter()
374 .map(|d| &d * self.config.step_size)
375 .collect()
376 } else {
377 phi_star
378 .iter()
379 .map(|phi| phi * self.config.step_size)
380 .collect()
381 };
382
383 let avg_update_norm: f64 =
385 updates.iter().map(|u| u.dot(u).sqrt()).sum::<f64>() / n as f64;
386
387 for i in 0..n {
388 particles[i] = &particles[i] + &updates[i];
389 }
390
391 if _iter % 10 == 0 || _iter == self.config.max_iterations - 1 {
393 let elbo = self.estimate_elbo(&particles, &log_joint, bandwidth)?;
394 elbo_history.push(elbo);
395 }
396
397 if avg_update_norm < self.config.tolerance {
399 converged = true;
400 break;
401 }
402 }
403
404 let mut mean = Array1::zeros(dim);
406 for p in &particles {
407 mean = &mean + p;
408 }
409 mean /= n as f64;
410
411 let mut var = Array1::zeros(dim);
412 for p in &particles {
413 let diff = p - &mean;
414 var = &var + &(&diff * &diff);
415 }
416 var /= (n - 1) as f64;
417 let std_devs = var.mapv(f64::sqrt);
418
419 Ok(PosteriorResult {
420 means: mean,
421 std_devs,
422 elbo_history,
423 iterations: self.config.max_iterations,
424 converged,
425 samples: Some(particles),
426 })
427 }
428}
429
#[cfg(test)]
mod tests {
    use super::*;

    /// Unnormalised 1-D Gaussian log-density and gradient with the given mean
    /// and variance.
    fn gaussian_1d(
        mean: f64,
        var: f64,
    ) -> impl Fn(&Array1<f64>) -> StatsResult<(f64, Array1<f64>)> {
        move |theta: &Array1<f64>| {
            let x = theta[0];
            let log_p = -0.5 * (x - mean).powi(2) / var;
            let grad = Array1::from_vec(vec![-(x - mean) / var]);
            Ok((log_p, grad))
        }
    }

    /// Equal-weight mixture of two Gaussians centred at -3 and +3 (variance
    /// 0.5), evaluated with the log-sum-exp trick.
    fn bimodal_log_joint(theta: &Array1<f64>) -> StatsResult<(f64, Array1<f64>)> {
        let x = theta[0];
        let var = 0.5;
        let log_comp1 = -0.5 * (x + 3.0).powi(2) / var;
        let log_comp2 = -0.5 * (x - 3.0).powi(2) / var;
        let max_log = log_comp1.max(log_comp2);

        // Shifted component weights, shared by the density and its gradient.
        let w1 = (log_comp1 - max_log).exp();
        let w2 = (log_comp2 - max_log).exp();
        let log_p = max_log + (w1 + w2).ln();

        let total = w1 + w2;
        let grad_x = (w1 * (-(x + 3.0) / var) + w2 * (-(x - 3.0) / var)) / total;
        Ok((log_p, Array1::from_vec(vec![grad_x])))
    }

    /// Zero-mean 1-D Gaussian with variance 0.01.
    fn narrow_gaussian_log_joint(theta: &Array1<f64>) -> StatsResult<(f64, Array1<f64>)> {
        let x = theta[0];
        let var = 0.01;
        let log_p = -0.5 * x * x / var;
        let grad = Array1::from_vec(vec![-x / var]);
        Ok((log_p, grad))
    }

    /// Unit-covariance 2-D Gaussian centred at (1, -1).
    fn shifted_gaussian_2d_log_joint(theta: &Array1<f64>) -> StatsResult<(f64, Array1<f64>)> {
        let d0 = theta[0] - 1.0;
        let d1 = theta[1] + 1.0;
        let log_p = -0.5 * (d0 * d0 + d1 * d1);
        let grad = Array1::from_vec(vec![-d0, -d1]);
        Ok((log_p, grad))
    }

    /// The particle mean should land near the mean of a 1-D Gaussian target.
    #[test]
    fn test_svgd_gaussian_convergence() {
        let target_mean = 2.0_f64;
        let target_var = 0.5_f64;

        let mut svgd = Svgd::new(SvgdConfig {
            num_particles: 50,
            step_size: 0.1,
            max_iterations: 500,
            tolerance: 1e-5,
            seed: 42,
            init_spread: 2.0,
            use_adam: true,
            ..Default::default()
        });

        let result = svgd
            .fit(gaussian_1d(target_mean, target_var), 1)
            .expect("SVGD should succeed");

        let fitted_mean = result.means[0];
        assert!(
            (fitted_mean - target_mean).abs() < 0.5,
            "Mean should be near {}, got {}",
            target_mean,
            fitted_mean
        );
        assert!(
            result.samples.is_some(),
            "SVGD should return posterior samples"
        );
    }

    /// Particles should populate both modes of a symmetric two-component
    /// mixture.
    #[ignore = "slow: SVGD convergence test can exceed timeout"]
    #[test]
    fn test_svgd_bimodal() {
        let mut svgd = Svgd::new(SvgdConfig {
            num_particles: 100,
            step_size: 0.05,
            max_iterations: 1000,
            tolerance: 1e-6,
            seed: 123,
            init_spread: 5.0,
            use_adam: true,
            ..Default::default()
        });

        let result = svgd
            .fit(bimodal_log_joint, 1)
            .expect("SVGD should succeed");

        let samples = result.samples.as_ref().expect("should have samples");

        // Split the particles at the origin, midway between the two modes.
        let (mut left_count, mut right_count) = (0usize, 0usize);
        for p in samples {
            if p[0] < 0.0 {
                left_count += 1;
            }
            if p[0] >= 0.0 {
                right_count += 1;
            }
        }
        assert!(
            left_count > 5 && right_count > 5,
            "Particles should spread across both modes: left={}, right={}",
            left_count,
            right_count
        );
    }

    /// Even for a very narrow target the particles must not all coincide.
    #[test]
    fn test_svgd_repulsive_prevents_collapse() {
        let mut svgd = Svgd::new(SvgdConfig {
            num_particles: 30,
            step_size: 0.05,
            max_iterations: 200,
            tolerance: 1e-8,
            seed: 77,
            init_spread: 2.0,
            use_adam: true,
            ..Default::default()
        });

        let result = svgd
            .fit(narrow_gaussian_log_joint, 1)
            .expect("SVGD should succeed");

        let samples = result.samples.as_ref().expect("should have samples");

        // Population variance of the particles around the reported mean.
        let mean = result.means[0];
        let sum_sq_dev: f64 = samples.iter().map(|p| (p[0] - mean).powi(2)).sum::<f64>();
        let var: f64 = sum_sq_dev / samples.len() as f64;

        assert!(
            var > 1e-10,
            "Particle variance {} should be nonzero (repulsion prevents collapse)",
            var
        );
    }

    /// Both coordinates of the particle mean should approach a 2-D Gaussian
    /// target's mean.
    #[ignore = "slow: SVGD may exceed timeout on slow machines"]
    #[test]
    fn test_svgd_2d_gaussian() {
        let mut svgd = Svgd::new(SvgdConfig {
            num_particles: 80,
            step_size: 0.1,
            max_iterations: 500,
            tolerance: 1e-5,
            seed: 55,
            init_spread: 3.0,
            use_adam: true,
            ..Default::default()
        });

        let result = svgd
            .fit(shifted_gaussian_2d_log_joint, 2)
            .expect("SVGD should succeed");

        assert!(
            (result.means[0] - 1.0).abs() < 1.0,
            "Mean[0] should be near 1.0, got {}",
            result.means[0]
        );
        assert!(
            (result.means[1] - (-1.0)).abs() < 1.0,
            "Mean[1] should be near -1.0, got {}",
            result.means[1]
        );
    }

    /// A single particle is rejected by `fit`.
    #[test]
    fn test_svgd_validation() {
        let single_particle = SvgdConfig {
            num_particles: 1,
            ..Default::default()
        };
        let mut svgd = Svgd::new(single_particle);
        let result = svgd.fit(|_: &Array1<f64>| Ok((0.0, Array1::zeros(1))), 1);
        assert!(result.is_err());
    }

    /// The median heuristic returns a positive, moderate bandwidth for evenly
    /// spaced points.
    #[test]
    fn test_median_bandwidth() {
        let particles: Vec<Array1<f64>> = [0.0, 1.0, 2.0, 3.0, 4.0]
            .iter()
            .map(|&x| Array1::from_vec(vec![x]))
            .collect();
        let h = RbfKernel::median_bandwidth(&particles);
        assert!(h > 0.0, "Bandwidth should be positive");
        assert!(h < 10.0, "Bandwidth should be reasonable, got {}", h);
    }
}