/// A target distribution described by its log-density.
///
/// Implementors must be `Send + Sync`.
pub trait LogProb: Send + Sync {
    /// Evaluates the log-density at the parameter vector `theta`.
    fn log_prob(&self, theta: &[f64]) -> f64;
}
19
20pub struct LogProbFn<F: Fn(&[f64]) -> f64 + Send + Sync> {
22 f: F,
23}
24
25impl<F: Fn(&[f64]) -> f64 + Send + Sync> LogProbFn<F> {
26 pub fn new(f: F) -> Self {
28 Self { f }
29 }
30}
31
32impl<F: Fn(&[f64]) -> f64 + Send + Sync> LogProb for LogProbFn<F> {
33 fn log_prob(&self, theta: &[f64]) -> f64 {
34 (self.f)(theta)
35 }
36}
37
38pub trait Proposal: Send + Sync {
40 fn propose(&self, current: &[f64], rng: &mut McmcRng) -> Vec<f64>;
42
43 fn log_ratio(&self, proposed: &[f64], current: &[f64]) -> f64;
47}
48
49#[derive(Debug, Clone)]
55pub struct GaussianProposal {
56 pub step_size: f64,
57}
58
59impl GaussianProposal {
60 pub fn new(step_size: f64) -> Self {
62 Self { step_size }
63 }
64}
65
66impl Proposal for GaussianProposal {
67 fn propose(&self, current: &[f64], rng: &mut McmcRng) -> Vec<f64> {
68 current
69 .iter()
70 .map(|&x| x + rng.next_normal_scaled(0.0, self.step_size))
71 .collect()
72 }
73
74 fn log_ratio(&self, _proposed: &[f64], _current: &[f64]) -> f64 {
75 0.0
76 }
77}
78
/// Independence proposal: each coordinate is drawn from its own fixed Gaussian.
#[derive(Debug, Clone)]
pub struct IndependentGaussianProposal {
    /// Per-coordinate proposal means.
    pub mean: Vec<f64>,
    /// Per-coordinate proposal standard deviations.
    pub std: Vec<f64>,
}

impl IndependentGaussianProposal {
    /// Builds the proposal from per-coordinate means and standard deviations.
    ///
    /// # Panics
    ///
    /// In builds with debug assertions enabled, panics if `mean` and `std`
    /// differ in length. The lengths are not checked otherwise.
    pub fn new(mean: Vec<f64>, std: Vec<f64>) -> Self {
        debug_assert_eq!(
            mean.len(),
            std.len(),
            "mean and std must have the same length"
        );
        IndependentGaussianProposal { mean, std }
    }
}
103
/// Log-density of a Gaussian with mean `mu` and standard deviation `sigma` at `x`,
/// without the constant `-0.5 * ln(2π)` term.
#[inline]
fn log_normal_density(x: f64, mu: f64, sigma: f64) -> f64 {
    let standardized = (x - mu) / sigma;
    let quadratic_term = -0.5 * standardized.powi(2);
    quadratic_term - sigma.ln()
}
110
111impl Proposal for IndependentGaussianProposal {
112 fn propose(&self, _current: &[f64], rng: &mut McmcRng) -> Vec<f64> {
113 self.mean
114 .iter()
115 .zip(self.std.iter())
116 .map(|(&mu, &sigma)| rng.next_normal_scaled(mu, sigma))
117 .collect()
118 }
119
120 fn log_ratio(&self, proposed: &[f64], current: &[f64]) -> f64 {
121 let log_q_current: f64 = current
127 .iter()
128 .zip(self.mean.iter())
129 .zip(self.std.iter())
130 .map(|((&x, &mu), &sigma)| log_normal_density(x, mu, sigma))
131 .sum();
132 let log_q_proposed: f64 = proposed
133 .iter()
134 .zip(self.mean.iter())
135 .zip(self.std.iter())
136 .map(|((&x, &mu), &sigma)| log_normal_density(x, mu, sigma))
137 .sum();
138 log_q_current - log_q_proposed
139 }
140}
141
/// Small deterministic pseudo-random generator used by the samplers.
///
/// The state advances as a 64-bit linear congruential recurrence using
/// wrapping arithmetic, so equal seeds always give equal streams.
#[derive(Debug, Clone)]
pub struct McmcRng {
    /// Current state of the recurrence.
    state: u64,
}

impl McmcRng {
    /// Multiplier of the recurrence; also mixed into the seed by [`McmcRng::new`].
    const MULTIPLIER: u64 = 6364136223846793005;
    /// Additive increment of the recurrence.
    const INCREMENT: u64 = 1442695040888963407;

    /// Creates a generator whose initial state is `seed` plus the multiplier (wrapping).
    pub fn new(seed: u64) -> Self {
        McmcRng {
            state: seed.wrapping_add(Self::MULTIPLIER),
        }
    }

    /// Advances the state by one step and returns the new state.
    pub fn next_u64(&mut self) -> u64 {
        let advanced = self
            .state
            .wrapping_mul(Self::MULTIPLIER)
            .wrapping_add(Self::INCREMENT);
        self.state = advanced;
        advanced
    }

    /// Returns a uniform value in `[0, 1)` built from the top 53 bits of the next state.
    pub fn next_f64(&mut self) -> f64 {
        let top_53_bits = self.next_u64() >> 11;
        let unit = 1.0_f64 / (1u64 << 53) as f64;
        top_53_bits as f64 * unit
    }

    /// Returns a standard-normal draw via the Box-Muller transform.
    ///
    /// Two uniforms are consumed per call and only the cosine branch is used.
    pub fn next_normal(&mut self) -> f64 {
        // Clamp away from zero so the logarithm below stays finite.
        let radial_uniform = self.next_f64().max(f64::MIN_POSITIVE);
        let angular_uniform = self.next_f64();
        let radius = (-2.0 * radial_uniform.ln()).sqrt();
        let angle = std::f64::consts::TAU * angular_uniform;
        radius * angle.cos()
    }

    /// Returns `mean + std * z`, where `z` is a fresh standard-normal draw.
    pub fn next_normal_scaled(&mut self, mean: f64, std: f64) -> f64 {
        let z = self.next_normal();
        mean + std * z
    }
}
195
/// Run-length, thinning and seeding options shared by the samplers.
#[derive(Debug, Clone)]
pub struct McmcConfig {
    /// Number of draws kept in the result; must be greater than zero.
    pub n_samples: usize,
    /// Number of initial steps that are run but not recorded.
    pub n_warmup: usize,
    /// Keep every `thin`-th post-warmup state; must be greater than zero.
    pub thin: usize,
    /// Seed for the sampler's random number generator.
    pub seed: u64,
    /// Target acceptance rate.
    ///
    /// NOTE(review): neither sampler in this file reads this field, since no
    /// step-size adaptation is performed here; confirm whether other code uses it.
    pub target_acceptance: f64,
}

impl Default for McmcConfig {
    /// 1000 samples, 500 warmup steps, no thinning, seed 42, target acceptance 0.234.
    fn default() -> Self {
        Self {
            n_samples: 1000,
            n_warmup: 500,
            thin: 1,
            seed: 42,
            target_acceptance: 0.234,
        }
    }
}

impl McmcConfig {
    /// Starts a builder chain from the default configuration.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the number of retained draws.
    pub fn n_samples(mut self, n: usize) -> Self {
        self.n_samples = n;
        self
    }

    /// Sets the number of discarded warmup steps.
    pub fn n_warmup(mut self, n: usize) -> Self {
        self.n_warmup = n;
        self
    }

    /// Sets the thinning interval.
    pub fn thin(mut self, t: usize) -> Self {
        self.thin = t;
        self
    }

    /// Sets the random seed.
    pub fn seed(mut self, s: u64) -> Self {
        self.seed = s;
        self
    }

    /// Sets the target acceptance rate, completing the builder so every field
    /// can be configured in one chain.
    pub fn target_acceptance(mut self, rate: f64) -> Self {
        self.target_acceptance = rate;
        self
    }
}
255
/// Summary statistics for a single chain.
#[derive(Debug, Clone)]
pub struct ChainDiagnostics {
    /// Number of recorded draws the statistics are based on.
    pub n_samples: usize,
    /// Acceptance rate as supplied by whoever built the diagnostics.
    pub acceptance_rate: f64,
    /// Per-dimension sample mean.
    pub mean: Vec<f64>,
    /// Per-dimension sample variance.
    pub variance: Vec<f64>,
    /// Per-dimension effective sample size estimate.
    pub effective_sample_size: Vec<f64>,
    /// Per-dimension R-hat, or `None` when it has not been computed.
    pub r_hat: Option<Vec<f64>>,
}
274
275#[derive(Debug, Clone)]
277pub struct McmcResult {
278 pub samples: Vec<Vec<f64>>,
280 pub log_probs: Vec<f64>,
282 pub diagnostics: ChainDiagnostics,
284}
285
286impl McmcResult {
287 pub fn n_samples(&self) -> usize {
289 self.samples.len()
290 }
291
292 pub fn n_dims(&self) -> usize {
294 self.samples.first().map(|s| s.len()).unwrap_or(0)
295 }
296
297 pub fn marginal_samples(&self, dim: usize) -> Vec<f64> {
299 self.samples.iter().map(|s| s[dim]).collect()
300 }
301
302 pub fn posterior_mean(&self) -> Vec<f64> {
304 let n = self.n_samples();
305 if n == 0 {
306 return vec![];
307 }
308 let d = self.n_dims();
309 let mut mean = vec![0.0_f64; d];
310 for sample in &self.samples {
311 for (m, &v) in mean.iter_mut().zip(sample.iter()) {
312 *m += v;
313 }
314 }
315 mean.iter_mut().for_each(|m| *m /= n as f64);
316 mean
317 }
318
319 pub fn posterior_variance(&self) -> Vec<f64> {
321 let n = self.n_samples();
322 if n < 2 {
323 return vec![0.0; self.n_dims()];
324 }
325 let mean = self.posterior_mean();
326 let d = self.n_dims();
327 let mut var = vec![0.0_f64; d];
328 for sample in &self.samples {
329 for (v, (&x, &mu)) in var.iter_mut().zip(sample.iter().zip(mean.iter())) {
330 *v += (x - mu).powi(2);
331 }
332 }
333 var.iter_mut().for_each(|v| *v /= (n - 1) as f64);
334 var
335 }
336
337 pub fn credible_interval(&self, dim: usize, alpha: f64) -> (f64, f64) {
341 let mut marginal = self.marginal_samples(dim);
342 marginal.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
343 let n = marginal.len();
344 if n == 0 {
345 return (f64::NAN, f64::NAN);
346 }
347 let lo_idx = ((alpha / 2.0) * n as f64) as usize;
348 let hi_idx = ((1.0 - alpha / 2.0) * n as f64) as usize;
349 let lo = marginal[lo_idx.min(n - 1)];
350 let hi = marginal[hi_idx.min(n - 1)];
351 (lo, hi)
352 }
353}
354
/// Errors reported by the samplers.
#[derive(Debug)]
pub enum McmcError {
    /// A configuration or input value was rejected; the payload explains why.
    InvalidConfig(String),
    /// The initial state and the model disagree on dimension.
    DimensionMismatch,
    /// A numerical problem occurred; the payload explains what.
    NumericalError(String),
}

impl std::fmt::Display for McmcError {
    /// Writes a human-readable message prefixed with `MCMC`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (prefix, detail) = match self {
            Self::InvalidConfig(reason) => ("MCMC invalid configuration: ", reason.as_str()),
            Self::DimensionMismatch => {
                ("MCMC dimension mismatch between initial state and model", "")
            }
            Self::NumericalError(reason) => ("MCMC numerical error: ", reason.as_str()),
        };
        f.write_str(prefix)?;
        f.write_str(detail)
    }
}

impl std::error::Error for McmcError {}
381
382fn validate_config(config: &McmcConfig) -> Result<(), McmcError> {
386 if config.n_samples == 0 {
387 return Err(McmcError::InvalidConfig(
388 "n_samples must be > 0".to_string(),
389 ));
390 }
391 if config.thin == 0 {
392 return Err(McmcError::InvalidConfig("thin must be > 0".to_string()));
393 }
394 Ok(())
395}
396
/// Returns `(mean, variance)` of `data`, with the variance using an `n - 1` denominator.
///
/// An empty slice yields `(0.0, 0.0)`; a single element yields a variance of `0.0`.
fn slice_stats(data: &[f64]) -> (f64, f64) {
    if data.is_empty() {
        return (0.0, 0.0);
    }
    let count = data.len();
    let mean = data.iter().sum::<f64>() / count as f64;
    if count < 2 {
        return (mean, 0.0);
    }
    let squared_deviations = data.iter().map(|&x| (x - mean).powi(2)).sum::<f64>();
    (mean, squared_deviations / (count - 1) as f64)
}
411
412pub struct MetropolisHastings<P: LogProb, Q: Proposal> {
420 log_prob: P,
421 proposal: Q,
422 config: McmcConfig,
423}
424
425impl<P: LogProb, Q: Proposal> MetropolisHastings<P, Q> {
426 pub fn new(log_prob: P, proposal: Q, config: McmcConfig) -> Self {
428 Self {
429 log_prob,
430 proposal,
431 config,
432 }
433 }
434
435 pub fn sample(&self, initial: &[f64]) -> Result<McmcResult, McmcError> {
437 validate_config(&self.config)?;
438 if initial.is_empty() {
439 return Err(McmcError::InvalidConfig(
440 "initial state must be non-empty".to_string(),
441 ));
442 }
443
444 let mut rng = McmcRng::new(self.config.seed);
445 let total_steps = self.config.n_warmup + self.config.n_samples * self.config.thin;
446
447 let mut current: Vec<f64> = initial.to_vec();
448 let mut current_lp = self.log_prob.log_prob(¤t);
449 if !current_lp.is_finite() {
450 return Err(McmcError::NumericalError(
451 "initial state has non-finite log probability".to_string(),
452 ));
453 }
454
455 let mut samples: Vec<Vec<f64>> = Vec::with_capacity(self.config.n_samples);
456 let mut log_probs: Vec<f64> = Vec::with_capacity(self.config.n_samples);
457 let mut n_accepted: usize = 0;
458 let mut step_in_sample: usize = 0; for step in 0..total_steps {
461 let proposed = self.proposal.propose(¤t, &mut rng);
462 let proposed_lp = self.log_prob.log_prob(&proposed);
463
464 let log_accept = if proposed_lp.is_finite() {
465 let log_alpha =
466 proposed_lp - current_lp + self.proposal.log_ratio(&proposed, ¤t);
467 log_alpha.min(0.0)
468 } else {
469 f64::NEG_INFINITY
470 };
471
472 let u = rng.next_f64();
473 let accepted = u.ln() < log_accept;
474
475 if accepted {
476 current = proposed;
477 current_lp = proposed_lp;
478 if step >= self.config.n_warmup {
479 n_accepted += 1;
480 }
481 }
482
483 if step >= self.config.n_warmup {
485 step_in_sample += 1;
486 if step_in_sample == self.config.thin {
487 samples.push(current.clone());
488 log_probs.push(current_lp);
489 step_in_sample = 0;
490 }
491 }
492 }
493
494 let n_post_warmup_steps = self.config.n_samples * self.config.thin;
495 let acceptance_rate = if n_post_warmup_steps > 0 {
496 n_accepted as f64 / n_post_warmup_steps as f64
497 } else {
498 0.0
499 };
500
501 let diagnostics = compute_diagnostics_with_acceptance(&samples, acceptance_rate);
502 Ok(McmcResult {
503 samples,
504 log_probs,
505 diagnostics,
506 })
507 }
508}
509
510pub struct HamiltonianMonteCarlo<P: LogProb> {
517 log_prob: P,
518 step_size: f64,
519 n_leapfrog_steps: usize,
520 config: McmcConfig,
521}
522
523impl<P: LogProb> HamiltonianMonteCarlo<P> {
524 pub fn new(log_prob: P, step_size: f64, n_leapfrog_steps: usize, config: McmcConfig) -> Self {
529 Self {
530 log_prob,
531 step_size,
532 n_leapfrog_steps,
533 config,
534 }
535 }
536
537 pub fn sample(&self, initial: &[f64]) -> Result<McmcResult, McmcError> {
539 validate_config(&self.config)?;
540 if initial.is_empty() {
541 return Err(McmcError::InvalidConfig(
542 "initial state must be non-empty".to_string(),
543 ));
544 }
545 if self.step_size <= 0.0 {
546 return Err(McmcError::InvalidConfig(
547 "step_size must be positive".to_string(),
548 ));
549 }
550 if self.n_leapfrog_steps == 0 {
551 return Err(McmcError::InvalidConfig(
552 "n_leapfrog_steps must be > 0".to_string(),
553 ));
554 }
555
556 let mut rng = McmcRng::new(self.config.seed);
557 let total_steps = self.config.n_warmup + self.config.n_samples * self.config.thin;
558 let d = initial.len();
559
560 let mut current: Vec<f64> = initial.to_vec();
561 let mut current_lp = self.log_prob.log_prob(¤t);
562 if !current_lp.is_finite() {
563 return Err(McmcError::NumericalError(
564 "initial state has non-finite log probability".to_string(),
565 ));
566 }
567
568 let mut samples: Vec<Vec<f64>> = Vec::with_capacity(self.config.n_samples);
569 let mut log_probs: Vec<f64> = Vec::with_capacity(self.config.n_samples);
570 let mut n_accepted: usize = 0;
571 let mut step_in_sample: usize = 0;
572
573 for step in 0..total_steps {
574 let momentum: Vec<f64> = (0..d).map(|_| rng.next_normal()).collect();
576
577 let ke_old: f64 = momentum.iter().map(|&r| 0.5 * r * r).sum();
579
580 let (proposed, new_momentum) = self.leapfrog(¤t, &momentum);
582
583 let proposed_lp = self.log_prob.log_prob(&proposed);
584 let ke_new: f64 = new_momentum.iter().map(|&r| 0.5 * r * r).sum();
585
586 let h_old = -current_lp + ke_old;
588 let h_new = -proposed_lp + ke_new;
589
590 let log_accept = if proposed_lp.is_finite() {
591 (h_old - h_new).min(0.0)
592 } else {
593 f64::NEG_INFINITY
594 };
595
596 let u = rng.next_f64();
597 let accepted = u.ln() < log_accept;
598
599 if accepted {
600 current = proposed;
601 current_lp = proposed_lp;
602 if step >= self.config.n_warmup {
603 n_accepted += 1;
604 }
605 }
606
607 if step >= self.config.n_warmup {
608 step_in_sample += 1;
609 if step_in_sample == self.config.thin {
610 samples.push(current.clone());
611 log_probs.push(current_lp);
612 step_in_sample = 0;
613 }
614 }
615 }
616
617 let n_post_warmup_steps = self.config.n_samples * self.config.thin;
618 let acceptance_rate = if n_post_warmup_steps > 0 {
619 n_accepted as f64 / n_post_warmup_steps as f64
620 } else {
621 0.0
622 };
623
624 let diagnostics = compute_diagnostics_with_acceptance(&samples, acceptance_rate);
625 Ok(McmcResult {
626 samples,
627 log_probs,
628 diagnostics,
629 })
630 }
631
632 fn grad_log_prob(&self, theta: &[f64], eps: f64) -> Vec<f64> {
636 let d = theta.len();
637 let mut grad = vec![0.0_f64; d];
638 let mut theta_plus = theta.to_vec();
639 let mut theta_minus = theta.to_vec();
640 for i in 0..d {
641 theta_plus[i] = theta[i] + eps;
642 theta_minus[i] = theta[i] - eps;
643 grad[i] = (self.log_prob.log_prob(&theta_plus) - self.log_prob.log_prob(&theta_minus))
644 / (2.0 * eps);
645 theta_plus[i] = theta[i];
646 theta_minus[i] = theta[i];
647 }
648 grad
649 }
650
651 fn leapfrog(&self, theta: &[f64], momentum: &[f64]) -> (Vec<f64>, Vec<f64>) {
655 let eps = self.step_size;
656 let fd_eps = 1e-5_f64;
658
659 let mut q = theta.to_vec();
660 let mut p = momentum.to_vec();
661 let d = q.len();
662
663 let grad = self.grad_log_prob(&q, fd_eps);
665 for i in 0..d {
666 p[i] += 0.5 * eps * grad[i];
667 }
668
669 for step in 0..self.n_leapfrog_steps {
670 for i in 0..d {
672 q[i] += eps * p[i];
673 }
674
675 if step < self.n_leapfrog_steps - 1 {
677 let grad_q = self.grad_log_prob(&q, fd_eps);
678 for i in 0..d {
679 p[i] += eps * grad_q[i];
680 }
681 }
682 }
683
684 let grad_final = self.grad_log_prob(&q, fd_eps);
686 for i in 0..d {
687 p[i] += 0.5 * eps * grad_final[i];
688 }
689
690 for pi in p.iter_mut() {
692 *pi = -*pi;
693 }
694
695 (q, p)
696 }
697}
698
/// Batch-means estimate of the effective sample size of a scalar chain.
///
/// The chain is cut into consecutive batches of length `floor(sqrt(n))`; the
/// estimate is `n * chain_variance / (batch_len * variance_of_batch_means)`,
/// clamped to `[1, n]`.
///
/// Chains with fewer than four samples, or too short for two batches, return
/// `n`. A chain with zero variance returns `1.0`.
pub fn effective_sample_size(samples: &[f64]) -> f64 {
    let n = samples.len();
    let n_f = n as f64;
    if n < 4 {
        return n_f;
    }

    let batch_len = n_f.sqrt() as usize;
    let n_batches = n / batch_len;
    if n_batches < 2 {
        return n_f;
    }

    let grand_mean = samples.iter().sum::<f64>() / n_f;
    let chain_var = samples
        .iter()
        .map(|&x| (x - grand_mean).powi(2))
        .sum::<f64>()
        / (n - 1) as f64;
    if chain_var == 0.0 {
        return 1.0;
    }

    // `chunks_exact` yields exactly `n_batches` full batches and skips any remainder.
    let batch_mean_var = samples
        .chunks_exact(batch_len)
        .map(|batch| {
            let batch_mean = batch.iter().sum::<f64>() / batch_len as f64;
            (batch_mean - grand_mean).powi(2)
        })
        .sum::<f64>()
        / (n_batches - 1) as f64;

    (n_f * chain_var / (batch_len as f64 * batch_mean_var)).clamp(1.0, n_f)
}
745
/// Gelman-Rubin potential scale reduction factor (R-hat) for one scalar parameter.
///
/// Every chain is truncated to the length of the shortest one. Returns NaN when
/// there are fewer than two chains, fewer than two samples per chain, or the
/// within-chain variance is exactly zero.
pub fn gelman_rubin(chains: &[Vec<f64>]) -> f64 {
    let m = chains.len();
    if m < 2 {
        return f64::NAN;
    }

    let n = chains.iter().map(Vec::len).min().unwrap_or(0);
    if n < 2 {
        return f64::NAN;
    }
    let n_f = n as f64;
    let m_f = m as f64;

    // (mean, unbiased variance) of each truncated chain.
    let per_chain: Vec<(f64, f64)> = chains
        .iter()
        .map(|chain| {
            let head = &chain[..n];
            let mu = head.iter().sum::<f64>() / n_f;
            let s2 = head.iter().map(|&x| (x - mu).powi(2)).sum::<f64>() / (n - 1) as f64;
            (mu, s2)
        })
        .collect();

    let grand_mean = per_chain.iter().map(|&(mu, _)| mu).sum::<f64>() / m_f;

    // Between-chain variance.
    let between = n_f
        * per_chain
            .iter()
            .map(|&(mu, _)| (mu - grand_mean).powi(2))
            .sum::<f64>()
        / (m - 1) as f64;

    // Within-chain variance.
    let within = per_chain.iter().map(|&(_, s2)| s2).sum::<f64>() / m_f;
    if within == 0.0 {
        return f64::NAN;
    }

    let pooled = (n - 1) as f64 / n_f * within + between / n_f;
    (pooled / within).sqrt()
}
795
/// Sample autocorrelation of `samples` at the given `lag`.
///
/// The lagged covariance is averaged over the `n - lag` available pairs and
/// divided by the variance computed with an `n` denominator. Returns `0.0` for
/// an empty slice or when `lag >= n`, and `1.0` for `lag == 0` or a
/// zero-variance series.
pub fn autocorrelation(samples: &[f64], lag: usize) -> f64 {
    let n = samples.len();
    // Also covers the empty slice, because every lag is >= 0.
    if lag >= n {
        return 0.0;
    }
    if lag == 0 {
        return 1.0;
    }

    let mean = samples.iter().sum::<f64>() / n as f64;
    let variance = samples.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n as f64;
    if variance == 0.0 {
        return 1.0;
    }

    // Each window of length `lag + 1` pairs its first and last element,
    // giving exactly `n - lag` lagged pairs in order.
    let n_pairs = n - lag;
    let lagged_cov = samples
        .windows(lag + 1)
        .map(|w| (w[0] - mean) * (w[lag] - mean))
        .sum::<f64>()
        / n_pairs as f64;

    lagged_cov / variance
}
825
826pub fn compute_diagnostics(samples: &[Vec<f64>]) -> ChainDiagnostics {
832 compute_diagnostics_with_acceptance(samples, 0.0)
833}
834
835pub(crate) fn compute_diagnostics_with_acceptance(
837 samples: &[Vec<f64>],
838 acceptance_rate: f64,
839) -> ChainDiagnostics {
840 let n = samples.len();
841 if n == 0 {
842 return ChainDiagnostics {
843 n_samples: 0,
844 acceptance_rate,
845 mean: vec![],
846 variance: vec![],
847 effective_sample_size: vec![],
848 r_hat: None,
849 };
850 }
851
852 let d = samples[0].len();
853 let mut mean = vec![0.0_f64; d];
854 let mut variance = vec![0.0_f64; d];
855 let mut ess = vec![0.0_f64; d];
856
857 for dim in 0..d {
858 let col: Vec<f64> = samples.iter().map(|s| s[dim]).collect();
859 let (m, v) = slice_stats(&col);
860 mean[dim] = m;
861 variance[dim] = v;
862 ess[dim] = effective_sample_size(&col);
863 }
864
865 ChainDiagnostics {
866 n_samples: n,
867 acceptance_rate,
868 mean,
869 variance,
870 effective_sample_size: ess,
871 r_hat: None,
872 }
873}
874
/// Unit tests for the random number generator, proposals, samplers and diagnostics.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- Random number generator ----

    /// Uniform draws stay inside `[0, 1)`.
    #[test]
    fn test_rng_uniform_in_range() {
        let mut rng = McmcRng::new(1234);
        for _ in 0..10_000 {
            let v = rng.next_f64();
            assert!(v >= 0.0, "uniform sample below 0: {}", v);
            assert!(v < 1.0, "uniform sample >= 1: {}", v);
        }
    }

    /// Normal draws have a sample mean close to zero.
    #[test]
    fn test_rng_normal_mean() {
        let mut rng = McmcRng::new(42);
        let samples: Vec<f64> = (0..1000).map(|_| rng.next_normal()).collect();
        let mean = samples.iter().sum::<f64>() / samples.len() as f64;
        assert!(
            mean.abs() < 0.15,
            "Box-Muller mean too far from 0: {}",
            mean
        );
    }

    /// Normal draws have a sample standard deviation close to one.
    #[test]
    fn test_rng_normal_std() {
        let mut rng = McmcRng::new(99);
        let samples: Vec<f64> = (0..1000).map(|_| rng.next_normal()).collect();
        let mean = samples.iter().sum::<f64>() / samples.len() as f64;
        let var = samples.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / samples.len() as f64;
        let std = var.sqrt();
        assert!(
            (std - 1.0).abs() < 0.15,
            "Box-Muller std too far from 1: {}",
            std
        );
    }

    // ---- Proposals ----

    /// The random-walk proposal reports a zero correction term.
    #[test]
    fn test_gaussian_proposal_log_ratio_is_zero() {
        let proposal = GaussianProposal::new(0.1);
        let current = vec![1.0, 2.0, 3.0];
        let proposed = vec![1.1, 2.2, 3.3];
        assert_eq!(
            proposal.log_ratio(&proposed, &current),
            0.0,
            "Gaussian RW should be symmetric"
        );
    }

    /// A proposal with a non-zero step size moves away from the current state.
    #[test]
    fn test_gaussian_proposal_changes_state() {
        let proposal = GaussianProposal::new(1.0);
        let mut rng = McmcRng::new(7);
        let current = vec![0.0, 0.0, 0.0];
        let proposed = proposal.propose(&current, &mut rng);
        assert_ne!(proposed, current, "proposal should change the state");
    }

    /// One-dimensional standard-normal target (unnormalised) shared by the sampler tests.
    fn standard_normal_lp() -> LogProbFn<impl Fn(&[f64]) -> f64 + Send + Sync> {
        LogProbFn::new(|theta: &[f64]| -0.5 * theta[0].powi(2))
    }

    // ---- Metropolis-Hastings ----

    #[test]
    fn test_mh_standard_normal_mean() {
        let lp = standard_normal_lp();
        let proposal = GaussianProposal::new(1.0);
        let config = McmcConfig::new().n_samples(2000).n_warmup(500).seed(123);
        let sampler = MetropolisHastings::new(lp, proposal, config);
        let result = sampler.sample(&[0.0]).expect("sampling failed");
        let mean = result.posterior_mean()[0];
        assert!(
            mean.abs() < 0.3,
            "MH posterior mean too far from 0: {}",
            mean
        );
    }

    #[test]
    fn test_mh_standard_normal_variance() {
        let lp = standard_normal_lp();
        let proposal = GaussianProposal::new(1.0);
        let config = McmcConfig::new().n_samples(2000).n_warmup(500).seed(77);
        let sampler = MetropolisHastings::new(lp, proposal, config);
        let result = sampler.sample(&[0.0]).expect("sampling failed");
        let var = result.posterior_variance()[0];
        assert!(
            (var - 1.0).abs() < 0.5,
            "MH posterior variance too far from 1: {}",
            var
        );
    }

    #[test]
    fn test_mh_acceptance_rate_in_range() {
        let lp = standard_normal_lp();
        let proposal = GaussianProposal::new(1.0);
        let config = McmcConfig::new().n_samples(1000).n_warmup(200).seed(55);
        let sampler = MetropolisHastings::new(lp, proposal, config);
        let result = sampler.sample(&[0.0]).expect("sampling failed");
        let ar = result.diagnostics.acceptance_rate;
        assert!(ar > 0.0, "acceptance rate should be > 0");
        assert!(ar <= 1.0, "acceptance rate should be <= 1");
    }

    #[test]
    fn test_mh_sample_count_matches_config() {
        let lp = standard_normal_lp();
        let proposal = GaussianProposal::new(1.0);
        let n = 300;
        let config = McmcConfig::new().n_samples(n).n_warmup(100).seed(11);
        let sampler = MetropolisHastings::new(lp, proposal, config);
        let result = sampler.sample(&[0.0]).expect("sampling failed");
        assert_eq!(result.n_samples(), n, "sample count should match config");
    }

    #[test]
    fn test_mh_warmup_discarded() {
        let lp = standard_normal_lp();
        let proposal = GaussianProposal::new(1.0);
        let n_samples = 200;
        let n_warmup = 100;
        let config = McmcConfig::new()
            .n_samples(n_samples)
            .n_warmup(n_warmup)
            .seed(42);
        let sampler = MetropolisHastings::new(lp, proposal, config);
        let result = sampler.sample(&[0.0]).expect("sampling failed");
        assert_eq!(
            result.n_samples(),
            n_samples,
            "warmup samples should not be included in result"
        );
    }

    // ---- Result helpers ----

    #[test]
    fn test_marginal_samples_correct() {
        let samples = vec![vec![1.0, 10.0], vec![2.0, 20.0], vec![3.0, 30.0]];
        let result = McmcResult {
            log_probs: vec![-1.0, -2.0, -3.0],
            diagnostics: compute_diagnostics(&samples),
            samples,
        };
        let m0 = result.marginal_samples(0);
        assert_eq!(m0, vec![1.0, 2.0, 3.0]);
        let m1 = result.marginal_samples(1);
        assert_eq!(m1, vec![10.0, 20.0, 30.0]);
    }

    #[test]
    fn test_credible_interval_contains_true_value() {
        let lp = standard_normal_lp();
        let proposal = GaussianProposal::new(1.0);
        let config = McmcConfig::new().n_samples(2000).n_warmup(500).seed(88);
        let sampler = MetropolisHastings::new(lp, proposal, config);
        let result = sampler.sample(&[0.0]).expect("sampling failed");
        let (lo, hi) = result.credible_interval(0, 0.05);
        assert!(
            lo < 0.0 && 0.0 < hi,
            "95% CI should contain the true mean 0.0; got ({}, {})",
            lo,
            hi
        );
    }

    // ---- Hamiltonian Monte Carlo ----

    #[test]
    fn test_hmc_standard_normal_mean() {
        let lp = LogProbFn::new(|theta: &[f64]| -0.5 * theta[0].powi(2));
        let config = McmcConfig::new().n_samples(1000).n_warmup(500).seed(321);
        let sampler = HamiltonianMonteCarlo::new(lp, 0.3, 10, config);
        let result = sampler.sample(&[0.0]).expect("HMC failed");
        let mean = result.posterior_mean()[0];
        assert!(
            mean.abs() < 0.4,
            "HMC posterior mean too far from 0: {}",
            mean
        );
    }

    #[test]
    fn test_hmc_acceptance_rate_high() {
        let lp = LogProbFn::new(|theta: &[f64]| -0.5 * theta[0].powi(2));
        let config = McmcConfig::new().n_samples(500).n_warmup(200).seed(999);
        let sampler = HamiltonianMonteCarlo::new(lp, 0.1, 5, config);
        let result = sampler.sample(&[0.0]).expect("HMC failed");
        let ar = result.diagnostics.acceptance_rate;
        assert!(
            ar > 0.5,
            "HMC acceptance rate should be > 0.5 with small step size: {}",
            ar
        );
    }

    /// The finite-difference gradient of `-x^2 / 2` at `x = 1` is close to `-1`.
    #[test]
    fn test_hmc_gradient_finite_difference_accuracy() {
        let hmc = HamiltonianMonteCarlo::new(
            LogProbFn::new(|theta: &[f64]| -0.5 * theta[0].powi(2)),
            0.1,
            5,
            McmcConfig::new(),
        );
        let grad = hmc.grad_log_prob(&[1.0], 1e-5);
        assert!(
            (grad[0] - (-1.0)).abs() < 1e-6,
            "gradient inaccurate: expected -1, got {}",
            grad[0]
        );
    }

    // ---- Diagnostics ----

    #[test]
    fn test_ess_positive_for_iid() {
        let mut rng = McmcRng::new(1);
        let samples: Vec<f64> = (0..200).map(|_| rng.next_normal()).collect();
        let ess = effective_sample_size(&samples);
        assert!(ess > 0.0, "ESS should be positive");
    }

    #[test]
    fn test_ess_at_most_n_samples() {
        let mut rng = McmcRng::new(2);
        let samples: Vec<f64> = (0..200).map(|_| rng.next_normal()).collect();
        let ess = effective_sample_size(&samples);
        assert!(
            ess <= samples.len() as f64,
            "ESS should not exceed number of samples"
        );
    }

    #[test]
    fn test_autocorrelation_lag_zero() {
        let samples: Vec<f64> = (0..100).map(|i| i as f64).collect();
        let ac = autocorrelation(&samples, 0);
        assert!(
            (ac - 1.0).abs() < 1e-10,
            "autocorrelation at lag 0 should be 1.0, got {}",
            ac
        );
    }

    #[test]
    fn test_autocorrelation_large_lag_near_zero() {
        let mut rng = McmcRng::new(3);
        let samples: Vec<f64> = (0..500).map(|_| rng.next_normal()).collect();
        let ac = autocorrelation(&samples, 100);
        assert!(
            ac.abs() < 0.2,
            "autocorrelation at large lag should be near 0 for iid: {}",
            ac
        );
    }

    #[test]
    fn test_gelman_rubin_converged_chains() {
        let mut rng = McmcRng::new(5);
        let chain1: Vec<f64> = (0..200).map(|_| rng.next_normal()).collect();
        let chain2: Vec<f64> = (0..200).map(|_| rng.next_normal()).collect();
        let r_hat = gelman_rubin(&[chain1, chain2]);
        assert!(
            !r_hat.is_nan(),
            "R-hat should not be NaN for well-behaved chains"
        );
        assert!(
            r_hat < 1.2,
            "R-hat should be near 1.0 for converged chains, got {}",
            r_hat
        );
    }

    /// Two chains sitting about 100 units apart must give a large R-hat.
    #[test]
    fn test_gelman_rubin_non_converged_chains() {
        let chain1: Vec<f64> = (0..200).map(|i| i as f64 * 0.01).collect();
        let chain2: Vec<f64> = (0..200).map(|i| 100.0 + i as f64 * 0.01).collect();
        let r_hat = gelman_rubin(&[chain1, chain2]);
        assert!(
            r_hat > 1.1,
            "R-hat should be > 1.1 for non-converged chains, got {}",
            r_hat
        );
    }

    // ---- Configuration and errors ----

    #[test]
    fn test_mcmc_config_builder_pattern() {
        let cfg = McmcConfig::new()
            .n_samples(500)
            .n_warmup(250)
            .thin(2)
            .seed(17);
        assert_eq!(cfg.n_samples, 500);
        assert_eq!(cfg.n_warmup, 250);
        assert_eq!(cfg.thin, 2);
        assert_eq!(cfg.seed, 17);
    }

    #[test]
    fn test_mcmc_error_display() {
        let e = McmcError::InvalidConfig("test error".to_string());
        let s = e.to_string();
        assert!(
            s.contains("test error"),
            "error Display should contain the message"
        );
        let e2 = McmcError::DimensionMismatch;
        assert!(
            !e2.to_string().is_empty(),
            "DimensionMismatch display should not be empty"
        );
        let e3 = McmcError::NumericalError("NaN".to_string());
        assert!(
            e3.to_string().contains("NaN"),
            "NumericalError display should contain the message"
        );
    }
}