Skip to main content

tensorlogic_infer/
mcmc.rs

1//! Markov Chain Monte Carlo (MCMC) sampling algorithms.
2//!
3//! Provides general-purpose posterior sampling via:
4//! - **Metropolis-Hastings**: Classic accept/reject sampler with pluggable proposals
5//! - **Hamiltonian Monte Carlo (HMC)**: Gradient-based sampler with leapfrog integration
6//! - **Chain diagnostics**: ESS (batch means), Gelman-Rubin R-hat, autocorrelation
7//!
8//! This module is intentionally distinct from the graphical-model-specific Gibbs sampler
9//! found in `tensorlogic-quantrs-hooks`.
10
11// ─── Traits ─────────────────────────────────────────────────────────────────
12
/// An unnormalized log-density over a real parameter vector.
///
/// Given `θ`, an implementation returns `log p(θ)` up to an additive constant.
/// The `Send + Sync` supertraits let a target be shared across threads.
pub trait LogProb: Send + Sync {
    /// Evaluate the (unnormalized) log-probability at `theta`.
    fn log_prob(&self, theta: &[f64]) -> f64;
}

/// Adapter that turns any closure `Fn(&[f64]) -> f64` into a [`LogProb`].
pub struct LogProbFn<F>
where
    F: Fn(&[f64]) -> f64 + Send + Sync,
{
    f: F,
}

impl<F> LogProbFn<F>
where
    F: Fn(&[f64]) -> f64 + Send + Sync,
{
    /// Wrap `f` so it can be used wherever a [`LogProb`] is expected.
    pub fn new(f: F) -> Self {
        LogProbFn { f }
    }
}

impl<F> LogProb for LogProbFn<F>
where
    F: Fn(&[f64]) -> f64 + Send + Sync,
{
    /// Forward the call to the wrapped closure unchanged.
    fn log_prob(&self, theta: &[f64]) -> f64 {
        let density = &self.f;
        density(theta)
    }
}
37
/// Proposal distribution for Metropolis-Hastings sampling.
///
/// An implementation supplies two things: a way to draw a candidate state and
/// the log-density correction needed when the proposal is not symmetric.
pub trait Proposal: Send + Sync {
    /// Sample a proposed next state given the current state.
    ///
    /// All randomness must be drawn from `rng`.
    fn propose(&self, current: &[f64], rng: &mut McmcRng) -> Vec<f64>;

    /// Log proposal ratio: `log q(x|y) − log q(y|x)`, where `x` is `current`
    /// and `y` is `proposed`.
    ///
    /// Returns `0.0` for symmetric proposals (e.g. Gaussian random walk).
    fn log_ratio(&self, proposed: &[f64], current: &[f64]) -> f64;
}
48
49// ─── Proposals ───────────────────────────────────────────────────────────────
50
/// Symmetric Gaussian random-walk proposal: `θ' = θ + N(0, step_size²)`.
///
/// Because the kernel is symmetric, [`Proposal::log_ratio`] is always `0.0`
/// for this type.
#[derive(Debug, Clone)]
pub struct GaussianProposal {
    /// Standard deviation of the per-coordinate perturbation.
    pub step_size: f64,
}

impl GaussianProposal {
    /// Build a random-walk proposal whose perturbations have standard
    /// deviation `step_size`.
    ///
    /// The value is stored as given; no validation is performed.
    pub fn new(step_size: f64) -> Self {
        GaussianProposal { step_size }
    }
}
65
66impl Proposal for GaussianProposal {
67    fn propose(&self, current: &[f64], rng: &mut McmcRng) -> Vec<f64> {
68        current
69            .iter()
70            .map(|&x| x + rng.next_normal_scaled(0.0, self.step_size))
71            .collect()
72    }
73
74    fn log_ratio(&self, _proposed: &[f64], _current: &[f64]) -> f64 {
75        0.0
76    }
77}
78
/// Independence proposal that draws dimension `i` from `N(mean_i, std_i²)`.
///
/// The draw does not depend on the current state, so the Metropolis-Hastings
/// correction term is generally non-zero and has to be evaluated explicitly.
#[derive(Debug, Clone)]
pub struct IndependentGaussianProposal {
    /// Per-dimension proposal means.
    pub mean: Vec<f64>,
    /// Per-dimension proposal standard deviations.
    pub std: Vec<f64>,
}

impl IndependentGaussianProposal {
    /// Build an independence proposal from per-dimension means and standard
    /// deviations.
    ///
    /// # Panics (debug)
    /// Debug builds panic when `mean` and `std` differ in length; release
    /// builds perform no check.
    pub fn new(mean: Vec<f64>, std: Vec<f64>) -> Self {
        let (n_mean, n_std) = (mean.len(), std.len());
        debug_assert_eq!(n_mean, n_std, "mean and std must have the same length");
        IndependentGaussianProposal { mean, std }
    }
}
103
/// Log-density of `N(mu, sigma²)` at `x`, without the additive constant
/// `-0.5 * ln(2π)`.
#[inline]
fn log_normal_density(x: f64, mu: f64, sigma: f64) -> f64 {
    let standardized = (x - mu) / sigma;
    let quadratic = -0.5 * standardized.powi(2);
    quadratic - sigma.ln()
}
110
111impl Proposal for IndependentGaussianProposal {
112    fn propose(&self, _current: &[f64], rng: &mut McmcRng) -> Vec<f64> {
113        self.mean
114            .iter()
115            .zip(self.std.iter())
116            .map(|(&mu, &sigma)| rng.next_normal_scaled(mu, sigma))
117            .collect()
118    }
119
120    fn log_ratio(&self, proposed: &[f64], current: &[f64]) -> f64 {
121        // log q(current | proposed) - log q(proposed | current)
122        // Both are independent, so:
123        //   log q(x | y) = sum_i log N(x_i; mean_i, std_i)   (ignores y)
124        //   same for log q(y | x)
125        // Therefore: log q(current) - log q(proposed)
126        let log_q_current: f64 = current
127            .iter()
128            .zip(self.mean.iter())
129            .zip(self.std.iter())
130            .map(|((&x, &mu), &sigma)| log_normal_density(x, mu, sigma))
131            .sum();
132        let log_q_proposed: f64 = proposed
133            .iter()
134            .zip(self.mean.iter())
135            .zip(self.std.iter())
136            .map(|((&x, &mu), &sigma)| log_normal_density(x, mu, sigma))
137            .sum();
138        log_q_current - log_q_proposed
139    }
140}
141
142// ─── RNG ─────────────────────────────────────────────────────────────────────
143
/// Small, reproducible pseudo-random generator built on a 64-bit linear
/// congruential generator (LCG).
///
/// Intended for MCMC use where statistical quality, not cryptographic
/// security, is what matters. The same seed always yields the same stream.
#[derive(Debug, Clone)]
pub struct McmcRng {
    state: u64,
}

impl McmcRng {
    /// Seed a new generator.
    ///
    /// The seed is offset by a fixed constant (wrapping) before use.
    pub fn new(seed: u64) -> Self {
        McmcRng {
            state: seed.wrapping_add(6364136223846793005),
        }
    }

    /// Step the LCG once and return the new 64-bit state.
    pub fn next_u64(&mut self) -> u64 {
        // Multiplier / increment pair attributed to Knuth (MMIX).
        const MULTIPLIER: u64 = 6364136223846793005;
        const INCREMENT: u64 = 1442695040888963407;

        let advanced = self
            .state
            .wrapping_mul(MULTIPLIER)
            .wrapping_add(INCREMENT);
        self.state = advanced;
        advanced
    }

    /// Uniform draw in `[0, 1)`.
    pub fn next_f64(&mut self) -> f64 {
        // 2^-53: the top 53 bits of the state fill an f64 mantissa exactly.
        const INV_2_POW_53: f64 = 1.0_f64 / (1u64 << 53) as f64;

        let mantissa = self.next_u64() >> 11;
        mantissa as f64 * INV_2_POW_53
    }

    /// Standard normal draw via the Box-Muller transform.
    ///
    /// Two uniforms are consumed per call; only the cosine branch of the
    /// transform is returned.
    pub fn next_normal(&mut self) -> f64 {
        // Clamp away from zero so the logarithm stays finite.
        let radial_u = self.next_f64().max(f64::MIN_POSITIVE);
        let angular_u = self.next_f64();

        let radius = (-2.0 * radial_u.ln()).sqrt();
        let angle = std::f64::consts::TAU * angular_u;
        radius * angle.cos()
    }

    /// Normal draw with the given mean and standard deviation.
    pub fn next_normal_scaled(&mut self, mean: f64, std: f64) -> f64 {
        let z = self.next_normal();
        mean + std * z
    }
}
195
196// ─── Configuration ────────────────────────────────────────────────────────────
197
/// Configuration shared by all MCMC samplers.
///
/// Build one with [`McmcConfig::new`] (or `Default`) and adjust it with the
/// chained setter methods; every setter consumes and returns the config.
#[derive(Debug, Clone)]
pub struct McmcConfig {
    /// Number of post-warmup samples to collect (default: 1000).
    pub n_samples: usize,
    /// Number of burn-in steps to discard (default: 500).
    pub n_warmup: usize,
    /// Thinning factor: keep every `thin`-th sample (default: 1).
    pub thin: usize,
    /// RNG seed for reproducibility (default: 42).
    pub seed: u64,
    /// Target acceptance rate for adaptive step-size tuning (default: 0.234 for MH).
    ///
    /// NOTE(review): no sampler code visible next to this struct reads this
    /// field — confirm where it is consumed.
    pub target_acceptance: f64,
}

impl Default for McmcConfig {
    fn default() -> Self {
        Self {
            n_samples: 1000,
            n_warmup: 500,
            thin: 1,
            seed: 42,
            target_acceptance: 0.234,
        }
    }
}

impl McmcConfig {
    /// Create a new configuration with default values.
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the number of post-warmup samples.
    pub fn n_samples(mut self, n: usize) -> Self {
        self.n_samples = n;
        self
    }

    /// Set the number of burn-in (warmup) steps.
    pub fn n_warmup(mut self, n: usize) -> Self {
        self.n_warmup = n;
        self
    }

    /// Set the thinning factor.
    pub fn thin(mut self, t: usize) -> Self {
        self.thin = t;
        self
    }

    /// Set the RNG seed.
    pub fn seed(mut self, s: u64) -> Self {
        self.seed = s;
        self
    }

    /// Set the target acceptance rate.
    ///
    /// Provided so that every field can be set through the builder chain;
    /// the value is stored as given without range checking.
    pub fn target_acceptance(mut self, rate: f64) -> Self {
        self.target_acceptance = rate;
        self
    }
}
255
256// ─── Results & Diagnostics ────────────────────────────────────────────────────
257
/// Per-chain diagnostics computed from the collected samples.
#[derive(Debug, Clone)]
pub struct ChainDiagnostics {
    /// Total number of collected samples.
    pub n_samples: usize,
    /// Fraction of proposals that were accepted.
    pub acceptance_rate: f64,
    /// Per-dimension posterior mean.
    pub mean: Vec<f64>,
    /// Per-dimension posterior variance.
    pub variance: Vec<f64>,
    /// Effective sample size per dimension (batch-means estimator).
    pub effective_sample_size: Vec<f64>,
    /// Gelman-Rubin R-hat per dimension (requires multiple chains; `None` for a single chain).
    pub r_hat: Option<Vec<f64>>,
}

/// Complete result returned by an MCMC sampler.
#[derive(Debug, Clone)]
pub struct McmcResult {
    /// Collected samples: outer index is sample index, inner is parameter dimension.
    pub samples: Vec<Vec<f64>>,
    /// Log-probability at each collected sample.
    pub log_probs: Vec<f64>,
    /// Chain-level diagnostics.
    pub diagnostics: ChainDiagnostics,
}

impl McmcResult {
    /// Number of samples collected.
    pub fn n_samples(&self) -> usize {
        self.samples.len()
    }

    /// Number of parameter dimensions, taken from the first sample
    /// (`0` when there are no samples).
    pub fn n_dims(&self) -> usize {
        self.samples.first().map_or(0, Vec::len)
    }

    /// Extract all samples for a single dimension.
    ///
    /// # Panics
    /// Panics if any stored sample has `dim` out of range.
    pub fn marginal_samples(&self, dim: usize) -> Vec<f64> {
        self.samples.iter().map(|s| s[dim]).collect()
    }

    /// Compute the posterior mean across all dimensions.
    ///
    /// Returns an empty vector when there are no samples.
    pub fn posterior_mean(&self) -> Vec<f64> {
        let n = self.n_samples();
        if n == 0 {
            return vec![];
        }
        let mut mean = vec![0.0_f64; self.n_dims()];
        for sample in &self.samples {
            for (m, &v) in mean.iter_mut().zip(sample.iter()) {
                *m += v;
            }
        }
        mean.iter_mut().for_each(|m| *m /= n as f64);
        mean
    }

    /// Compute the posterior variance (unbiased, `n - 1` denominator) across
    /// all dimensions.
    ///
    /// With fewer than two samples the variance is reported as `0.0` for every
    /// dimension.
    pub fn posterior_variance(&self) -> Vec<f64> {
        let n = self.n_samples();
        let d = self.n_dims();
        if n < 2 {
            return vec![0.0; d];
        }
        let mean = self.posterior_mean();
        let mut var = vec![0.0_f64; d];
        for sample in &self.samples {
            for (v, (&x, &mu)) in var.iter_mut().zip(sample.iter().zip(mean.iter())) {
                *v += (x - mu).powi(2);
            }
        }
        var.iter_mut().for_each(|v| *v /= (n - 1) as f64);
        var
    }

    /// Compute the `(alpha/2, 1 - alpha/2)` credible interval for a single dimension.
    ///
    /// Returns `(lower, upper)` quantiles. `alpha = 0.05` gives a 95 % interval.
    /// Returns `(NaN, NaN)` when there are no samples.
    ///
    /// # Panics
    /// Panics if any stored sample has `dim` out of range.
    pub fn credible_interval(&self, dim: usize, alpha: f64) -> (f64, f64) {
        let mut marginal = self.marginal_samples(dim);
        let n = marginal.len();
        if n == 0 {
            return (f64::NAN, f64::NAN);
        }
        // `total_cmp` is a true total order, so the sort stays well-defined
        // even if a sample is NaN (positive NaN sorts after every number).
        // Mapping incomparable pairs to `Equal` is not a consistent order and
        // leaves the sorted result unspecified.
        marginal.sort_unstable_by(f64::total_cmp);

        // Quantile positions are truncated to an index and clamped so that
        // `alpha` near 0 still yields the last element.
        let lo_idx = ((alpha / 2.0) * n as f64) as usize;
        let hi_idx = ((1.0 - alpha / 2.0) * n as f64) as usize;
        (marginal[lo_idx.min(n - 1)], marginal[hi_idx.min(n - 1)])
    }
}
354
355// ─── Error ────────────────────────────────────────────────────────────────────
356
/// Failure modes reported by the MCMC samplers.
#[derive(Debug)]
pub enum McmcError {
    /// The sampler configuration is unusable (for example, zero samples were
    /// requested); the payload explains what is wrong.
    InvalidConfig(String),
    /// The initial state and the model disagree on dimensionality.
    DimensionMismatch,
    /// A numerical problem occurred (for example, a NaN log-probability); the
    /// payload carries the details.
    NumericalError(String),
}

impl std::fmt::Display for McmcError {
    /// Render a human-readable, single-line description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidConfig(detail) => {
                write!(f, "MCMC invalid configuration: {}", detail)
            }
            Self::DimensionMismatch => {
                f.write_str("MCMC dimension mismatch between initial state and model")
            }
            Self::NumericalError(detail) => {
                write!(f, "MCMC numerical error: {}", detail)
            }
        }
    }
}

impl std::error::Error for McmcError {}
381
382// ─── Internal helpers ─────────────────────────────────────────────────────────
383
384/// Validate a configuration and return an error if it is unusable.
385fn validate_config(config: &McmcConfig) -> Result<(), McmcError> {
386    if config.n_samples == 0 {
387        return Err(McmcError::InvalidConfig(
388            "n_samples must be > 0".to_string(),
389        ));
390    }
391    if config.thin == 0 {
392        return Err(McmcError::InvalidConfig("thin must be > 0".to_string()));
393    }
394    Ok(())
395}
396
/// Return `(mean, unbiased variance)` of `data`.
///
/// An empty slice yields `(0.0, 0.0)`; a single element yields its value and
/// a variance of `0.0`.
fn slice_stats(data: &[f64]) -> (f64, f64) {
    if data.is_empty() {
        return (0.0, 0.0);
    }

    let count = data.len();
    let mean = data.iter().sum::<f64>() / count as f64;
    if count < 2 {
        return (mean, 0.0);
    }

    let sum_sq_dev: f64 = data.iter().map(|&x| (x - mean).powi(2)).sum();
    (mean, sum_sq_dev / (count - 1) as f64)
}
411
412// ─── Metropolis-Hastings ──────────────────────────────────────────────────────
413
414/// Metropolis-Hastings MCMC sampler.
415///
416/// Runs a single chain with an arbitrary [`Proposal`] distribution. The chain
417/// is initialized at `initial`, runs for `n_warmup + n_samples * thin` steps,
418/// discards the warmup, and returns every `thin`-th sample.
419pub struct MetropolisHastings<P: LogProb, Q: Proposal> {
420    log_prob: P,
421    proposal: Q,
422    config: McmcConfig,
423}
424
425impl<P: LogProb, Q: Proposal> MetropolisHastings<P, Q> {
426    /// Create a new Metropolis-Hastings sampler.
427    pub fn new(log_prob: P, proposal: Q, config: McmcConfig) -> Self {
428        Self {
429            log_prob,
430            proposal,
431            config,
432        }
433    }
434
435    /// Run the Metropolis-Hastings chain from `initial` and return the collected samples.
436    pub fn sample(&self, initial: &[f64]) -> Result<McmcResult, McmcError> {
437        validate_config(&self.config)?;
438        if initial.is_empty() {
439            return Err(McmcError::InvalidConfig(
440                "initial state must be non-empty".to_string(),
441            ));
442        }
443
444        let mut rng = McmcRng::new(self.config.seed);
445        let total_steps = self.config.n_warmup + self.config.n_samples * self.config.thin;
446
447        let mut current: Vec<f64> = initial.to_vec();
448        let mut current_lp = self.log_prob.log_prob(&current);
449        if !current_lp.is_finite() {
450            return Err(McmcError::NumericalError(
451                "initial state has non-finite log probability".to_string(),
452            ));
453        }
454
455        let mut samples: Vec<Vec<f64>> = Vec::with_capacity(self.config.n_samples);
456        let mut log_probs: Vec<f64> = Vec::with_capacity(self.config.n_samples);
457        let mut n_accepted: usize = 0;
458        let mut step_in_sample: usize = 0; // counts post-warmup steps for thinning
459
460        for step in 0..total_steps {
461            let proposed = self.proposal.propose(&current, &mut rng);
462            let proposed_lp = self.log_prob.log_prob(&proposed);
463
464            let log_accept = if proposed_lp.is_finite() {
465                let log_alpha =
466                    proposed_lp - current_lp + self.proposal.log_ratio(&proposed, &current);
467                log_alpha.min(0.0)
468            } else {
469                f64::NEG_INFINITY
470            };
471
472            let u = rng.next_f64();
473            let accepted = u.ln() < log_accept;
474
475            if accepted {
476                current = proposed;
477                current_lp = proposed_lp;
478                if step >= self.config.n_warmup {
479                    n_accepted += 1;
480                }
481            }
482
483            // Collect post-warmup samples, applying thinning
484            if step >= self.config.n_warmup {
485                step_in_sample += 1;
486                if step_in_sample == self.config.thin {
487                    samples.push(current.clone());
488                    log_probs.push(current_lp);
489                    step_in_sample = 0;
490                }
491            }
492        }
493
494        let n_post_warmup_steps = self.config.n_samples * self.config.thin;
495        let acceptance_rate = if n_post_warmup_steps > 0 {
496            n_accepted as f64 / n_post_warmup_steps as f64
497        } else {
498            0.0
499        };
500
501        let diagnostics = compute_diagnostics_with_acceptance(&samples, acceptance_rate);
502        Ok(McmcResult {
503            samples,
504            log_probs,
505            diagnostics,
506        })
507    }
508}
509
510// ─── Hamiltonian Monte Carlo ──────────────────────────────────────────────────
511
512/// Hamiltonian Monte Carlo (HMC) sampler with leapfrog integration.
513///
514/// Gradients are estimated via central finite differences, so no analytic
515/// gradient implementation is required from the user.
516pub struct HamiltonianMonteCarlo<P: LogProb> {
517    log_prob: P,
518    step_size: f64,
519    n_leapfrog_steps: usize,
520    config: McmcConfig,
521}
522
523impl<P: LogProb> HamiltonianMonteCarlo<P> {
524    /// Create a new HMC sampler.
525    ///
526    /// * `step_size`: leapfrog step size `ε`.
527    /// * `n_leapfrog_steps`: number of leapfrog steps `L` per proposal.
528    pub fn new(log_prob: P, step_size: f64, n_leapfrog_steps: usize, config: McmcConfig) -> Self {
529        Self {
530            log_prob,
531            step_size,
532            n_leapfrog_steps,
533            config,
534        }
535    }
536
537    /// Run the HMC chain from `initial` and return collected samples.
538    pub fn sample(&self, initial: &[f64]) -> Result<McmcResult, McmcError> {
539        validate_config(&self.config)?;
540        if initial.is_empty() {
541            return Err(McmcError::InvalidConfig(
542                "initial state must be non-empty".to_string(),
543            ));
544        }
545        if self.step_size <= 0.0 {
546            return Err(McmcError::InvalidConfig(
547                "step_size must be positive".to_string(),
548            ));
549        }
550        if self.n_leapfrog_steps == 0 {
551            return Err(McmcError::InvalidConfig(
552                "n_leapfrog_steps must be > 0".to_string(),
553            ));
554        }
555
556        let mut rng = McmcRng::new(self.config.seed);
557        let total_steps = self.config.n_warmup + self.config.n_samples * self.config.thin;
558        let d = initial.len();
559
560        let mut current: Vec<f64> = initial.to_vec();
561        let mut current_lp = self.log_prob.log_prob(&current);
562        if !current_lp.is_finite() {
563            return Err(McmcError::NumericalError(
564                "initial state has non-finite log probability".to_string(),
565            ));
566        }
567
568        let mut samples: Vec<Vec<f64>> = Vec::with_capacity(self.config.n_samples);
569        let mut log_probs: Vec<f64> = Vec::with_capacity(self.config.n_samples);
570        let mut n_accepted: usize = 0;
571        let mut step_in_sample: usize = 0;
572
573        for step in 0..total_steps {
574            // Sample momentum r ~ N(0, I)
575            let momentum: Vec<f64> = (0..d).map(|_| rng.next_normal()).collect();
576
577            // Kinetic energy at start: 0.5 * r^T r
578            let ke_old: f64 = momentum.iter().map(|&r| 0.5 * r * r).sum();
579
580            // Leapfrog integration
581            let (proposed, new_momentum) = self.leapfrog(&current, &momentum);
582
583            let proposed_lp = self.log_prob.log_prob(&proposed);
584            let ke_new: f64 = new_momentum.iter().map(|&r| 0.5 * r * r).sum();
585
586            // Hamiltonian H = -log p(θ) + KE
587            let h_old = -current_lp + ke_old;
588            let h_new = -proposed_lp + ke_new;
589
590            let log_accept = if proposed_lp.is_finite() {
591                (h_old - h_new).min(0.0)
592            } else {
593                f64::NEG_INFINITY
594            };
595
596            let u = rng.next_f64();
597            let accepted = u.ln() < log_accept;
598
599            if accepted {
600                current = proposed;
601                current_lp = proposed_lp;
602                if step >= self.config.n_warmup {
603                    n_accepted += 1;
604                }
605            }
606
607            if step >= self.config.n_warmup {
608                step_in_sample += 1;
609                if step_in_sample == self.config.thin {
610                    samples.push(current.clone());
611                    log_probs.push(current_lp);
612                    step_in_sample = 0;
613                }
614            }
615        }
616
617        let n_post_warmup_steps = self.config.n_samples * self.config.thin;
618        let acceptance_rate = if n_post_warmup_steps > 0 {
619            n_accepted as f64 / n_post_warmup_steps as f64
620        } else {
621            0.0
622        };
623
624        let diagnostics = compute_diagnostics_with_acceptance(&samples, acceptance_rate);
625        Ok(McmcResult {
626            samples,
627            log_probs,
628            diagnostics,
629        })
630    }
631
632    /// Estimate the gradient of `log_prob` at `theta` using central finite differences.
633    ///
634    /// Uses step size `eps` for the perturbation.
635    fn grad_log_prob(&self, theta: &[f64], eps: f64) -> Vec<f64> {
636        let d = theta.len();
637        let mut grad = vec![0.0_f64; d];
638        let mut theta_plus = theta.to_vec();
639        let mut theta_minus = theta.to_vec();
640        for i in 0..d {
641            theta_plus[i] = theta[i] + eps;
642            theta_minus[i] = theta[i] - eps;
643            grad[i] = (self.log_prob.log_prob(&theta_plus) - self.log_prob.log_prob(&theta_minus))
644                / (2.0 * eps);
645            theta_plus[i] = theta[i];
646            theta_minus[i] = theta[i];
647        }
648        grad
649    }
650
651    /// Leapfrog integrator: run `L` steps of size `ε` starting from `(theta, momentum)`.
652    ///
653    /// Returns `(theta*, momentum*)`.
654    fn leapfrog(&self, theta: &[f64], momentum: &[f64]) -> (Vec<f64>, Vec<f64>) {
655        let eps = self.step_size;
656        // Finite-difference step for gradient estimation. Choose adaptively.
657        let fd_eps = 1e-5_f64;
658
659        let mut q = theta.to_vec();
660        let mut p = momentum.to_vec();
661        let d = q.len();
662
663        // Half-step for momentum at the start
664        let grad = self.grad_log_prob(&q, fd_eps);
665        for i in 0..d {
666            p[i] += 0.5 * eps * grad[i];
667        }
668
669        for step in 0..self.n_leapfrog_steps {
670            // Full step for position
671            for i in 0..d {
672                q[i] += eps * p[i];
673            }
674
675            // Full step for momentum (except at last step, where it is a half-step)
676            if step < self.n_leapfrog_steps - 1 {
677                let grad_q = self.grad_log_prob(&q, fd_eps);
678                for i in 0..d {
679                    p[i] += eps * grad_q[i];
680                }
681            }
682        }
683
684        // Final half-step for momentum
685        let grad_final = self.grad_log_prob(&q, fd_eps);
686        for i in 0..d {
687            p[i] += 0.5 * eps * grad_final[i];
688        }
689
690        // Negate momentum to make the proposal reversible
691        for pi in p.iter_mut() {
692            *pi = -*pi;
693        }
694
695        (q, p)
696    }
697}
698
699// ─── Diagnostics ─────────────────────────────────────────────────────────────
700
/// Effective sample size (ESS) of a scalar chain via batch means.
///
/// The chain is cut into consecutive batches of length `floor(sqrt(n))`; any
/// leftover tail is not batched. The spread of the batch means around the
/// overall mean is compared with the naive chain variance:
/// `ESS = n * chain_var / (batch_len * batch_mean_var)`, clamped to `[1, n]`.
///
/// Special cases: chains with fewer than 4 samples return `n` unchanged, and
/// a chain with zero variance returns `1.0`.
pub fn effective_sample_size(samples: &[f64]) -> f64 {
    let n = samples.len();
    let n_f = n as f64;
    if n < 4 {
        return n_f;
    }

    let batch_len = n_f.sqrt() as usize;
    let n_batches = n / batch_len;
    if n_batches < 2 {
        return n_f;
    }

    let grand_mean = samples.iter().sum::<f64>() / n_f;

    // Naive (unbiased) variance of the whole chain.
    let chain_var = samples
        .iter()
        .map(|&x| (x - grand_mean).powi(2))
        .sum::<f64>()
        / (n - 1) as f64;
    if chain_var == 0.0 {
        return 1.0;
    }

    // Squared deviations of each full batch's mean from the overall mean.
    let batch_sq_dev: f64 = samples
        .chunks_exact(batch_len)
        .map(|batch| {
            let batch_mean = batch.iter().sum::<f64>() / batch_len as f64;
            (batch_mean - grand_mean).powi(2)
        })
        .sum();
    let batch_mean_var = batch_sq_dev / (n_batches - 1) as f64;

    let raw = n_f * chain_var / (batch_len as f64 * batch_mean_var);
    raw.clamp(1.0, n_f)
}
745
/// Gelman-Rubin R-hat statistic for one scalar quantity across several chains.
///
/// Values close to 1.0 suggest the chains agree; values above roughly 1.1
/// indicate potential non-convergence. Chains of unequal length are truncated
/// to the shortest one.
///
/// Returns `f64::NAN` (it does not panic) when fewer than two chains are
/// given, when the shortest chain has fewer than two samples, or when the
/// within-chain variance is exactly zero.
pub fn gelman_rubin(chains: &[Vec<f64>]) -> f64 {
    let m = chains.len();
    if m < 2 {
        return f64::NAN;
    }

    let n = chains.iter().map(Vec::len).min().unwrap_or(0);
    if n < 2 {
        return f64::NAN;
    }
    let n_f = n as f64;

    // Work on the common prefix of every chain.
    let truncated: Vec<&[f64]> = chains.iter().map(|c| &c[..n]).collect();

    let chain_means: Vec<f64> = truncated
        .iter()
        .map(|c| c.iter().sum::<f64>() / n_f)
        .collect();
    let overall_mean = chain_means.iter().sum::<f64>() / m as f64;

    // Between-chain variance B.
    let between_ss: f64 = chain_means
        .iter()
        .map(|&mu| (mu - overall_mean).powi(2))
        .sum();
    let b = n_f * between_ss / (m - 1) as f64;

    // Within-chain variance W: mean of the per-chain unbiased variances.
    let within_total: f64 = truncated
        .iter()
        .zip(chain_means.iter())
        .map(|(c, &mu)| c.iter().map(|&x| (x - mu).powi(2)).sum::<f64>() / (n - 1) as f64)
        .sum();
    let w = within_total / m as f64;
    if w == 0.0 {
        return f64::NAN;
    }

    // Pooled variance estimate, then the ratio against W.
    let var_hat = (n - 1) as f64 / n_f * w + b / n_f;
    (var_hat / w).sqrt()
}
795
/// Sample autocorrelation of `samples` at lag `lag`.
///
/// The lagged covariance is averaged over the `n - lag` available pairs and
/// divided by the population variance (denominator `n`). With this
/// normalisation the value is usually within `[-1, 1]`, but it can fall
/// outside that range, particularly for lags close to `n`.
///
/// Special cases: an empty slice or `lag >= n` returns `0.0`; lag 0 returns
/// `1.0`; a constant series returns `1.0`.
pub fn autocorrelation(samples: &[f64], lag: usize) -> f64 {
    let n = samples.len();
    if n == 0 || lag >= n {
        return 0.0;
    }
    if lag == 0 {
        return 1.0;
    }

    let n_f = n as f64;
    let mean = samples.iter().sum::<f64>() / n_f;
    let variance = samples.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n_f;
    if variance == 0.0 {
        return 1.0;
    }

    // Each window of length `lag + 1` pairs an element with the one `lag`
    // positions later; there are exactly `n - lag` such windows.
    let n_pairs = n - lag;
    let cross_sum: f64 = samples
        .windows(lag + 1)
        .map(|w| (w[0] - mean) * (w[lag] - mean))
        .sum();
    let cov = cross_sum / n_pairs as f64;

    cov / variance
}
825
826/// Compute [`ChainDiagnostics`] from a collection of samples (one per row).
827///
828/// This version assumes a single chain (so `r_hat` will be `None`).
829/// Acceptance rate is set to `0.0`; use `compute_diagnostics_with_acceptance`
830/// when you have the true acceptance rate.
831pub fn compute_diagnostics(samples: &[Vec<f64>]) -> ChainDiagnostics {
832    compute_diagnostics_with_acceptance(samples, 0.0)
833}
834
835/// Internal helper: compute diagnostics with a known acceptance rate.
836pub(crate) fn compute_diagnostics_with_acceptance(
837    samples: &[Vec<f64>],
838    acceptance_rate: f64,
839) -> ChainDiagnostics {
840    let n = samples.len();
841    if n == 0 {
842        return ChainDiagnostics {
843            n_samples: 0,
844            acceptance_rate,
845            mean: vec![],
846            variance: vec![],
847            effective_sample_size: vec![],
848            r_hat: None,
849        };
850    }
851
852    let d = samples[0].len();
853    let mut mean = vec![0.0_f64; d];
854    let mut variance = vec![0.0_f64; d];
855    let mut ess = vec![0.0_f64; d];
856
857    for dim in 0..d {
858        let col: Vec<f64> = samples.iter().map(|s| s[dim]).collect();
859        let (m, v) = slice_stats(&col);
860        mean[dim] = m;
861        variance[dim] = v;
862        ess[dim] = effective_sample_size(&col);
863    }
864
865    ChainDiagnostics {
866        n_samples: n,
867        acceptance_rate,
868        mean,
869        variance,
870        effective_sample_size: ess,
871        r_hat: None,
872    }
873}
874
875// ─── Tests ────────────────────────────────────────────────────────────────────
876
877#[cfg(test)]
878mod tests {
879    use super::*;
880
881    // ── McmcRng ──────────────────────────────────────────────────────────────
882
883    #[test]
884    fn test_rng_uniform_in_range() {
885        let mut rng = McmcRng::new(1234);
886        for _ in 0..10_000 {
887            let v = rng.next_f64();
888            assert!(v >= 0.0, "uniform sample below 0: {}", v);
889            assert!(v < 1.0, "uniform sample >= 1: {}", v);
890        }
891    }
892
893    #[test]
894    fn test_rng_normal_mean() {
895        let mut rng = McmcRng::new(42);
896        let samples: Vec<f64> = (0..1000).map(|_| rng.next_normal()).collect();
897        let mean = samples.iter().sum::<f64>() / samples.len() as f64;
898        assert!(
899            mean.abs() < 0.15,
900            "Box-Muller mean too far from 0: {}",
901            mean
902        );
903    }
904
905    #[test]
906    fn test_rng_normal_std() {
907        let mut rng = McmcRng::new(99);
908        let samples: Vec<f64> = (0..1000).map(|_| rng.next_normal()).collect();
909        let mean = samples.iter().sum::<f64>() / samples.len() as f64;
910        let var = samples.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / samples.len() as f64;
911        let std = var.sqrt();
912        assert!(
913            (std - 1.0).abs() < 0.15,
914            "Box-Muller std too far from 1: {}",
915            std
916        );
917    }
918
919    // ── GaussianProposal ─────────────────────────────────────────────────────
920
921    #[test]
922    fn test_gaussian_proposal_log_ratio_is_zero() {
923        let proposal = GaussianProposal::new(0.1);
924        let current = vec![1.0, 2.0, 3.0];
925        let proposed = vec![1.1, 2.2, 3.3];
926        assert_eq!(
927            proposal.log_ratio(&proposed, &current),
928            0.0,
929            "Gaussian RW should be symmetric"
930        );
931    }
932
933    #[test]
934    fn test_gaussian_proposal_changes_state() {
935        let proposal = GaussianProposal::new(1.0);
936        let mut rng = McmcRng::new(7);
937        let current = vec![0.0, 0.0, 0.0];
938        let proposed = proposal.propose(&current, &mut rng);
939        // It is astronomically unlikely for all three to remain exactly 0.
940        assert_ne!(proposed, current, "proposal should change the state");
941    }
942
943    // ── MetropolisHastings ───────────────────────────────────────────────────
944
945    /// Standard normal target: log p(θ) = -0.5 * θ²
946    fn standard_normal_lp() -> LogProbFn<impl Fn(&[f64]) -> f64 + Send + Sync> {
947        LogProbFn::new(|theta: &[f64]| -0.5 * theta[0].powi(2))
948    }
949
950    #[test]
951    fn test_mh_standard_normal_mean() {
952        let lp = standard_normal_lp();
953        let proposal = GaussianProposal::new(1.0);
954        let config = McmcConfig::new().n_samples(2000).n_warmup(500).seed(123);
955        let sampler = MetropolisHastings::new(lp, proposal, config);
956        let result = sampler.sample(&[0.0]).expect("sampling failed");
957        let mean = result.posterior_mean()[0];
958        assert!(
959            mean.abs() < 0.3,
960            "MH posterior mean too far from 0: {}",
961            mean
962        );
963    }
964
965    #[test]
966    fn test_mh_standard_normal_variance() {
967        let lp = standard_normal_lp();
968        let proposal = GaussianProposal::new(1.0);
969        let config = McmcConfig::new().n_samples(2000).n_warmup(500).seed(77);
970        let sampler = MetropolisHastings::new(lp, proposal, config);
971        let result = sampler.sample(&[0.0]).expect("sampling failed");
972        let var = result.posterior_variance()[0];
973        assert!(
974            (var - 1.0).abs() < 0.5,
975            "MH posterior variance too far from 1: {}",
976            var
977        );
978    }
979
980    #[test]
981    fn test_mh_acceptance_rate_in_range() {
982        let lp = standard_normal_lp();
983        let proposal = GaussianProposal::new(1.0);
984        let config = McmcConfig::new().n_samples(1000).n_warmup(200).seed(55);
985        let sampler = MetropolisHastings::new(lp, proposal, config);
986        let result = sampler.sample(&[0.0]).expect("sampling failed");
987        let ar = result.diagnostics.acceptance_rate;
988        assert!(ar > 0.0, "acceptance rate should be > 0");
989        assert!(ar <= 1.0, "acceptance rate should be <= 1");
990    }
991
992    #[test]
993    fn test_mh_sample_count_matches_config() {
994        let lp = standard_normal_lp();
995        let proposal = GaussianProposal::new(1.0);
996        let n = 300;
997        let config = McmcConfig::new().n_samples(n).n_warmup(100).seed(11);
998        let sampler = MetropolisHastings::new(lp, proposal, config);
999        let result = sampler.sample(&[0.0]).expect("sampling failed");
1000        assert_eq!(result.n_samples(), n, "sample count should match config");
1001    }
1002
1003    #[test]
1004    fn test_mh_warmup_discarded() {
1005        let lp = standard_normal_lp();
1006        let proposal = GaussianProposal::new(1.0);
1007        let n_samples = 200;
1008        let n_warmup = 100;
1009        let config = McmcConfig::new()
1010            .n_samples(n_samples)
1011            .n_warmup(n_warmup)
1012            .seed(42);
1013        let sampler = MetropolisHastings::new(lp, proposal, config);
1014        let result = sampler.sample(&[0.0]).expect("sampling failed");
1015        // Result should contain exactly n_samples, not n_samples + n_warmup
1016        assert_eq!(
1017            result.n_samples(),
1018            n_samples,
1019            "warmup samples should not be included in result"
1020        );
1021    }
1022
1023    // ── McmcResult ───────────────────────────────────────────────────────────
1024
1025    #[test]
1026    fn test_marginal_samples_correct() {
1027        let samples = vec![vec![1.0, 10.0], vec![2.0, 20.0], vec![3.0, 30.0]];
1028        let result = McmcResult {
1029            log_probs: vec![-1.0, -2.0, -3.0],
1030            diagnostics: compute_diagnostics(&samples),
1031            samples,
1032        };
1033        let m0 = result.marginal_samples(0);
1034        assert_eq!(m0, vec![1.0, 2.0, 3.0]);
1035        let m1 = result.marginal_samples(1);
1036        assert_eq!(m1, vec![10.0, 20.0, 30.0]);
1037    }
1038
1039    #[test]
1040    fn test_credible_interval_contains_true_value() {
1041        let lp = standard_normal_lp();
1042        let proposal = GaussianProposal::new(1.0);
1043        let config = McmcConfig::new().n_samples(2000).n_warmup(500).seed(88);
1044        let sampler = MetropolisHastings::new(lp, proposal, config);
1045        let result = sampler.sample(&[0.0]).expect("sampling failed");
1046        let (lo, hi) = result.credible_interval(0, 0.05); // 95% CI
1047        assert!(
1048            lo < 0.0 && 0.0 < hi,
1049            "95% CI should contain the true mean 0.0; got ({}, {})",
1050            lo,
1051            hi
1052        );
1053    }
1054
1055    // ── HamiltonianMonteCarlo ────────────────────────────────────────────────
1056
1057    #[test]
1058    fn test_hmc_standard_normal_mean() {
1059        let lp = LogProbFn::new(|theta: &[f64]| -0.5 * theta[0].powi(2));
1060        let config = McmcConfig::new().n_samples(1000).n_warmup(500).seed(321);
1061        let sampler = HamiltonianMonteCarlo::new(lp, 0.3, 10, config);
1062        let result = sampler.sample(&[0.0]).expect("HMC failed");
1063        let mean = result.posterior_mean()[0];
1064        assert!(
1065            mean.abs() < 0.4,
1066            "HMC posterior mean too far from 0: {}",
1067            mean
1068        );
1069    }
1070
1071    #[test]
1072    fn test_hmc_acceptance_rate_high() {
1073        let lp = LogProbFn::new(|theta: &[f64]| -0.5 * theta[0].powi(2));
1074        let config = McmcConfig::new().n_samples(500).n_warmup(200).seed(999);
1075        // Small step + few leapfrog steps should have high acceptance
1076        let sampler = HamiltonianMonteCarlo::new(lp, 0.1, 5, config);
1077        let result = sampler.sample(&[0.0]).expect("HMC failed");
1078        let ar = result.diagnostics.acceptance_rate;
1079        assert!(
1080            ar > 0.5,
1081            "HMC acceptance rate should be > 0.5 with small step size: {}",
1082            ar
1083        );
1084    }
1085
1086    #[test]
1087    fn test_hmc_gradient_finite_difference_accuracy() {
1088        // For f(x) = -0.5 x^2, the gradient at x=1 should be -1.
1089        let hmc = HamiltonianMonteCarlo::new(
1090            LogProbFn::new(|theta: &[f64]| -0.5 * theta[0].powi(2)),
1091            0.1,
1092            5,
1093            McmcConfig::new(),
1094        );
1095        let grad = hmc.grad_log_prob(&[1.0], 1e-5);
1096        assert!(
1097            (grad[0] - (-1.0)).abs() < 1e-6,
1098            "gradient inaccurate: expected -1, got {}",
1099            grad[0]
1100        );
1101    }
1102
1103    // ── Diagnostics ──────────────────────────────────────────────────────────
1104
1105    #[test]
1106    fn test_ess_positive_for_iid() {
1107        let mut rng = McmcRng::new(1);
1108        let samples: Vec<f64> = (0..200).map(|_| rng.next_normal()).collect();
1109        let ess = effective_sample_size(&samples);
1110        assert!(ess > 0.0, "ESS should be positive");
1111    }
1112
1113    #[test]
1114    fn test_ess_at_most_n_samples() {
1115        let mut rng = McmcRng::new(2);
1116        let samples: Vec<f64> = (0..200).map(|_| rng.next_normal()).collect();
1117        let ess = effective_sample_size(&samples);
1118        assert!(
1119            ess <= samples.len() as f64,
1120            "ESS should not exceed number of samples"
1121        );
1122    }
1123
1124    #[test]
1125    fn test_autocorrelation_lag_zero() {
1126        let samples: Vec<f64> = (0..100).map(|i| i as f64).collect();
1127        let ac = autocorrelation(&samples, 0);
1128        assert!(
1129            (ac - 1.0).abs() < 1e-10,
1130            "autocorrelation at lag 0 should be 1.0, got {}",
1131            ac
1132        );
1133    }
1134
1135    #[test]
1136    fn test_autocorrelation_large_lag_near_zero() {
1137        let mut rng = McmcRng::new(3);
1138        let samples: Vec<f64> = (0..500).map(|_| rng.next_normal()).collect();
1139        let ac = autocorrelation(&samples, 100);
1140        assert!(
1141            ac.abs() < 0.2,
1142            "autocorrelation at large lag should be near 0 for iid: {}",
1143            ac
1144        );
1145    }
1146
1147    #[test]
1148    fn test_gelman_rubin_converged_chains() {
1149        let mut rng = McmcRng::new(5);
1150        let chain1: Vec<f64> = (0..200).map(|_| rng.next_normal()).collect();
1151        let chain2: Vec<f64> = (0..200).map(|_| rng.next_normal()).collect();
1152        let r_hat = gelman_rubin(&[chain1, chain2]);
1153        assert!(
1154            !r_hat.is_nan(),
1155            "R-hat should not be NaN for well-behaved chains"
1156        );
1157        assert!(
1158            r_hat < 1.2,
1159            "R-hat should be near 1.0 for converged chains, got {}",
1160            r_hat
1161        );
1162    }
1163
1164    #[test]
1165    fn test_gelman_rubin_non_converged_chains() {
1166        // Two chains drawn from very different distributions
1167        let chain1: Vec<f64> = (0..200).map(|i| i as f64 * 0.01).collect(); // near 0..2
1168        let chain2: Vec<f64> = (0..200).map(|i| 100.0 + i as f64 * 0.01).collect(); // near 100..102
1169        let r_hat = gelman_rubin(&[chain1, chain2]);
1170        assert!(
1171            r_hat > 1.1,
1172            "R-hat should be > 1.1 for non-converged chains, got {}",
1173            r_hat
1174        );
1175    }
1176
1177    // ── McmcConfig builder ───────────────────────────────────────────────────
1178
1179    #[test]
1180    fn test_mcmc_config_builder_pattern() {
1181        let cfg = McmcConfig::new()
1182            .n_samples(500)
1183            .n_warmup(250)
1184            .thin(2)
1185            .seed(17);
1186        assert_eq!(cfg.n_samples, 500);
1187        assert_eq!(cfg.n_warmup, 250);
1188        assert_eq!(cfg.thin, 2);
1189        assert_eq!(cfg.seed, 17);
1190    }
1191
1192    // ── McmcError Display ────────────────────────────────────────────────────
1193
1194    #[test]
1195    fn test_mcmc_error_display() {
1196        let e = McmcError::InvalidConfig("test error".to_string());
1197        let s = e.to_string();
1198        assert!(
1199            s.contains("test error"),
1200            "error Display should contain the message"
1201        );
1202        let e2 = McmcError::DimensionMismatch;
1203        assert!(
1204            e2.to_string().len() > 0,
1205            "DimensionMismatch display should not be empty"
1206        );
1207        let e3 = McmcError::NumericalError("NaN".to_string());
1208        assert!(
1209            e3.to_string().contains("NaN"),
1210            "NumericalError display should contain the message"
1211        );
1212    }
1213}