Skip to main content

scirs2_stats/mcmc/
hamiltonian.rs

1//! Hamiltonian Monte Carlo (HMC) sampling
2//!
3//! HMC is a sophisticated MCMC method that uses gradient information to make
4//! more efficient proposals than random walk methods.
5
6use crate::error::{StatsError, StatsResult as Result};
7use scirs2_core::ndarray::{Array1, Array2};
8use scirs2_core::random::{Distribution, Normal};
9use scirs2_core::validation::*;
10use scirs2_core::{Rng, RngExt};
11use std::fmt::Debug;
12
13/// Target distribution trait with gradient information for HMC
14pub trait DifferentiableTarget: Send + Sync {
15    /// Compute the log probability density
16    fn log_density(&self, x: &Array1<f64>) -> f64;
17
18    /// Compute the gradient of the log density
19    fn gradient(&self, x: &Array1<f64>) -> Array1<f64>;
20
21    /// Get the dimensionality
22    fn dim(&self) -> usize;
23
24    /// Optional: compute both log density and gradient together for efficiency
25    fn log_density_and_gradient(&self, x: &Array1<f64>) -> (f64, Array1<f64>) {
26        (self.log_density(x), self.gradient(x))
27    }
28}
29
30/// Hamiltonian Monte Carlo sampler
31pub struct HamiltonianMonteCarlo<T: DifferentiableTarget> {
32    /// Target distribution
33    pub target: T,
34    /// Current position
35    pub position: Array1<f64>,
36    /// Current log density
37    pub current_log_density: f64,
38    /// Step size for leapfrog integration
39    pub stepsize: f64,
40    /// Number of leapfrog steps
41    pub n_steps: usize,
42    /// Mass matrix (identity for standard HMC)
43    pub mass_matrix: Array2<f64>,
44    /// Mass matrix inverse
45    pub mass_inv: Array2<f64>,
46    /// Number of accepted proposals
47    pub n_accepted: usize,
48    /// Total number of proposals
49    pub n_proposed: usize,
50}
51
52impl<T: DifferentiableTarget> HamiltonianMonteCarlo<T> {
53    /// Create a new HMC sampler
54    pub fn new(target: T, initial: Array1<f64>, stepsize: f64, nsteps: usize) -> Result<Self> {
55        checkarray_finite(&initial, "initial")?;
56        check_positive(stepsize, "stepsize")?;
57        check_positive(nsteps, "nsteps")?;
58
59        if initial.len() != target.dim() {
60            return Err(StatsError::DimensionMismatch(format!(
61                "initial dimension ({}) must match target dimension ({})",
62                initial.len(),
63                target.dim()
64            )));
65        }
66
67        let dim = initial.len();
68        let mass_matrix = Array2::eye(dim);
69        let mass_inv = Array2::eye(dim);
70        let current_log_density = target.log_density(&initial);
71
72        Ok(Self {
73            target,
74            position: initial,
75            current_log_density,
76            stepsize,
77            n_steps: nsteps,
78            mass_matrix,
79            mass_inv,
80            n_accepted: 0,
81            n_proposed: 0,
82        })
83    }
84
85    /// Set custom mass matrix
86    pub fn with_mass_matrix(mut self, massmatrix: Array2<f64>) -> Result<Self> {
87        checkarray_finite(&massmatrix, "massmatrix")?;
88
89        if massmatrix.nrows() != self.position.len() || massmatrix.ncols() != self.position.len() {
90            return Err(StatsError::DimensionMismatch(format!(
91                "massmatrix shape ({}, {}) must be ({}, {})",
92                massmatrix.nrows(),
93                massmatrix.ncols(),
94                self.position.len(),
95                self.position.len()
96            )));
97        }
98
99        // Compute inverse
100        let mass_inv = scirs2_linalg::inv(&massmatrix.view(), None).map_err(|e| {
101            StatsError::ComputationError(format!("Failed to invert mass matrix: {}", e))
102        })?;
103
104        self.mass_matrix = massmatrix;
105        self.mass_inv = mass_inv;
106        Ok(self)
107    }
108
109    /// Perform one HMC step
110    pub fn step<R: Rng + ?Sized>(&mut self, rng: &mut R) -> Result<Array1<f64>> {
111        let _dim = self.position.len();
112
113        // Sample momentum from N(0, M)
114        let momentum = self.sample_momentum(rng)?;
115
116        // Store initial state
117        let initial_position = self.position.clone();
118        let initial_momentum = momentum.clone();
119        let initial_log_density = self.current_log_density;
120
121        // Perform leapfrog integration
122        let (final_position, final_momentum) = self.leapfrog(initial_position.clone(), momentum)?;
123
124        // Compute Hamiltonian for initial and final states
125        let initial_hamiltonian =
126            -initial_log_density + 0.5 * self.kinetic_energy(&initial_momentum);
127        let final_log_density = self.target.log_density(&final_position);
128        let final_hamiltonian = -final_log_density + 0.5 * self.kinetic_energy(&final_momentum);
129
130        // Metropolis acceptance step
131        let log_alpha = -(final_hamiltonian - initial_hamiltonian);
132        let u: f64 = rng.random();
133
134        self.n_proposed += 1;
135
136        if u.ln() < log_alpha {
137            // Accept
138            self.position = final_position;
139            self.current_log_density = final_log_density;
140            self.n_accepted += 1;
141        }
142        // If rejected, keep current position
143
144        Ok(self.position.clone())
145    }
146
147    /// Sample momentum from N(0, M)
148    fn sample_momentum<R: Rng + ?Sized>(&self, rng: &mut R) -> Result<Array1<f64>> {
149        let dim = self.position.len();
150        let normal = Normal::new(0.0, 1.0).map_err(|e| {
151            StatsError::ComputationError(format!("Failed to create normal distribution: {}", e))
152        })?;
153
154        // Sample from standard normal
155        let z = Array1::from_shape_fn(dim, |_| normal.sample(rng));
156
157        // Transform to p ~ N(0, M) using Cholesky decomposition
158        // For simplicity, assume diagonal mass matrix
159        let mut momentum = Array1::zeros(dim);
160        for i in 0..dim {
161            momentum[i] = z[i] * self.mass_matrix[[i, i]].sqrt();
162        }
163
164        Ok(momentum)
165    }
166
167    /// Compute kinetic energy: 0.5 * p^T * M^{-1} * p
168    fn kinetic_energy(&self, momentum: &Array1<f64>) -> f64 {
169        // For diagonal mass matrix, this simplifies
170        let mut energy = 0.0;
171        for i in 0..momentum.len() {
172            energy += momentum[i] * momentum[i] * self.mass_inv[[i, i]];
173        }
174        0.5 * energy
175    }
176
177    /// Leapfrog integration
178    fn leapfrog(
179        &self,
180        mut position: Array1<f64>,
181        mut momentum: Array1<f64>,
182    ) -> Result<(Array1<f64>, Array1<f64>)> {
183        // Initial half step for momentum
184        let gradient = self.target.gradient(&position);
185        momentum = momentum + 0.5 * self.stepsize * gradient;
186
187        // Alternating full steps
188        for _ in 0..self.n_steps {
189            // Full step for position
190            let momentum_update = self.mass_inv.dot(&momentum);
191            position = position + self.stepsize * momentum_update;
192
193            // Full step for momentum (except last iteration)
194            if self.n_steps > 1 {
195                let gradient = self.target.gradient(&position);
196                momentum = momentum + self.stepsize * gradient;
197            }
198        }
199
200        // Final half step for momentum
201        let gradient = self.target.gradient(&position);
202        momentum = momentum + 0.5 * self.stepsize * gradient;
203
204        // Negate momentum for reversibility
205        momentum = -momentum;
206
207        Ok((position, momentum))
208    }
209
210    /// Sample multiple states
211    pub fn sample<R: Rng + ?Sized>(
212        &mut self,
213        n_samples_: usize,
214        rng: &mut R,
215    ) -> Result<Array2<f64>> {
216        let dim = self.position.len();
217        let mut samples = Array2::zeros((n_samples_, dim));
218
219        for i in 0..n_samples_ {
220            let sample = self.step(rng)?;
221            samples.row_mut(i).assign(&sample);
222        }
223
224        Ok(samples)
225    }
226
227    /// Sample with burn-in
228    pub fn sample_with_burnin<R: Rng + ?Sized>(
229        &mut self,
230        n_samples_: usize,
231        burnin: usize,
232        rng: &mut R,
233    ) -> Result<Array2<f64>> {
234        check_positive(burnin, "burnin")?;
235
236        // Burn-in
237        for _ in 0..burnin {
238            self.step(rng)?;
239        }
240
241        // Reset counters after burn-in
242        self.reset_counters();
243
244        // Collect _samples
245        self.sample(n_samples_, rng)
246    }
247
248    /// Get acceptance rate
249    pub fn acceptance_rate(&self) -> f64 {
250        if self.n_proposed == 0 {
251            0.0
252        } else {
253            self.n_accepted as f64 / self.n_proposed as f64
254        }
255    }
256
257    /// Reset acceptance counters
258    pub fn reset_counters(&mut self) {
259        self.n_accepted = 0;
260        self.n_proposed = 0;
261    }
262}
263
264/// No-U-Turn Sampler (NUTS) - adaptive version of HMC
265pub struct NoUTurnSampler<T: DifferentiableTarget> {
266    /// Base HMC sampler
267    pub hmc: HamiltonianMonteCarlo<T>,
268    /// Maximum tree depth
269    pub max_tree_depth: usize,
270    /// Target acceptance probability
271    pub target_accept_prob: f64,
272    /// Step size adaptation parameters
273    pub stepsize_adaptation: DualAveragingAdaptation,
274}
275
/// Dual averaging adaptation of the leapfrog step size (Hoffman & Gelman, 2014,
/// Algorithm 5).
///
/// After iteration `m` with observed acceptance statistic `alpha`:
///
/// * `H̄_m = (1 - 1/(m + t0)) H̄_{m-1} + (target - alpha) / (m + t0)`
/// * `ln ε_m = mu - sqrt(m) / gamma * H̄_m`
/// * `ln ε̄_m = m^{-kappa} ln ε_m + (1 - m^{-kappa}) ln ε̄_{m-1}`
///
/// `ε_m` is used during warm-up; `ε̄` ([`final_stepsize`](Self::final_stepsize)) is
/// the step size to keep once adaptation stops.
#[derive(Debug, Clone)]
pub struct DualAveragingAdaptation {
    /// Target acceptance probability `δ`.
    pub target: f64,
    /// Point `μ` towards which the log step size is shrunk (`ln(10 ε₀)` by default).
    pub mu: f64,
    /// Shrinkage scale `γ`; larger values keep `ln ε` closer to `mu`.
    pub gamma: f64,
    /// Iteration offset `t₀` that damps the influence of early iterations.
    pub t0: f64,
    /// Exponent `κ` of the averaging weight `m^{-κ}`.
    pub kappa: f64,
    /// Number of updates performed so far.
    pub iteration: usize,
    /// Weighted running average of the log step size, `ln ε̄`.
    pub log_step_avg: f64,
    /// Running average `H̄` of `target - alpha`.
    pub h_avg: f64,
}

impl DualAveragingAdaptation {
    /// Create a new adaptation state for acceptance target `target`, starting from
    /// log step size `initial_logstep` (`ln ε₀`).
    ///
    /// Uses the reference defaults `γ = 0.05`, `t₀ = 10`, `κ = 0.75` and
    /// `μ = ln(10 ε₀)`.
    pub fn new(target: f64, initial_logstep: f64) -> Self {
        Self {
            target,
            mu: std::f64::consts::LN_10 + initial_logstep,
            gamma: 0.05,
            t0: 10.0,
            kappa: 0.75,
            iteration: 0,
            log_step_avg: initial_logstep,
            h_avg: 0.0,
        }
    }

    /// Incorporate acceptance statistic `alpha` and return the step size `ε_m` to use
    /// for the next warm-up iteration.
    pub fn update(&mut self, alpha: f64) -> f64 {
        self.iteration += 1;
        let m = self.iteration as f64;

        // Running average of the acceptance error.
        self.h_avg =
            (1.0 - 1.0 / (m + self.t0)) * self.h_avg + (self.target - alpha) / (m + self.t0);

        // Shrink towards the fixed point `mu` (not the running average), scaled by sqrt(m).
        let log_step = self.mu - m.sqrt() / self.gamma * self.h_avg;

        // Polynomially decaying weighted average of the iterates.
        let weight = m.powf(-self.kappa);
        self.log_step_avg = weight * log_step + (1.0 - weight) * self.log_step_avg;

        log_step.exp()
    }

    /// Averaged step size `ε̄ = exp(log_step_avg)` to use once adaptation has ended.
    pub fn final_stepsize(&self) -> f64 {
        self.log_step_avg.exp()
    }
}
328
329impl<T: DifferentiableTarget> NoUTurnSampler<T> {
330    /// Create new NUTS sampler
331    pub fn new(target: T, initial: Array1<f64>, initial_stepsize: f64) -> Result<Self> {
332        let hmc = HamiltonianMonteCarlo::new(target, initial, initial_stepsize, 1)?;
333        let stepsize_adaptation = DualAveragingAdaptation::new(0.8, initial_stepsize.ln());
334
335        Ok(Self {
336            hmc,
337            max_tree_depth: 10,
338            target_accept_prob: 0.8,
339            stepsize_adaptation,
340        })
341    }
342
343    /// Perform one NUTS step
344    pub fn step<R: Rng + ?Sized>(&mut self, rng: &mut R) -> Result<Array1<f64>> {
345        // Sample momentum
346        let momentum = self.hmc.sample_momentum(rng)?;
347
348        // Build tree and sample
349        let (new_position, alpha) =
350            self.build_tree(self.hmc.position.clone(), momentum, 0.0, 1, rng)?;
351
352        // Update step size during adaptation
353        let new_stepsize = self.stepsize_adaptation.update(alpha);
354        self.hmc.stepsize = new_stepsize;
355
356        // Update position if different
357        if !new_position
358            .iter()
359            .zip(self.hmc.position.iter())
360            .all(|(a, b)| (a - b).abs() < f64::EPSILON)
361        {
362            self.hmc.position = new_position;
363            self.hmc.current_log_density = self.hmc.target.log_density(&self.hmc.position);
364            self.hmc.n_accepted += 1;
365        }
366
367        self.hmc.n_proposed += 1;
368        Ok(self.hmc.position.clone())
369    }
370
371    /// Build tree for NUTS algorithm (simplified version)
372    fn build_tree<R: Rng + ?Sized>(
373        &self,
374        position: Array1<f64>,
375        momentum: Array1<f64>,
376        log_u: f64,
377        depth: usize,
378        rng: &mut R,
379    ) -> Result<(Array1<f64>, f64)> {
380        if depth >= self.max_tree_depth {
381            // Base case: return input position with low acceptance
382            return Ok((position, 0.0));
383        }
384
385        // Perform leapfrog step
386        let (new_position, new_momentum) = self.hmc.leapfrog(position.clone(), momentum.clone())?;
387
388        // Compute log probability for new state
389        let new_log_density = self.hmc.target.log_density(&new_position);
390        let new_hamiltonian = -new_log_density + 0.5 * self.hmc.kinetic_energy(&new_momentum);
391
392        // Check if proposal is acceptable
393        let current_hamiltonian =
394            -self.hmc.current_log_density + 0.5 * self.hmc.kinetic_energy(&momentum);
395        let log_alpha = -(new_hamiltonian - current_hamiltonian);
396        let alpha = log_alpha.exp().min(1.0);
397
398        if log_u <= log_alpha {
399            Ok((new_position, alpha))
400        } else {
401            Ok((position, alpha))
402        }
403    }
404
405    /// Sample with adaptation
406    pub fn sample_adaptive<R: Rng + ?Sized>(
407        &mut self,
408        n_samples_: usize,
409        n_adapt: usize,
410        rng: &mut R,
411    ) -> Result<Array2<f64>> {
412        // Adaptation phase
413        for _ in 0..n_adapt {
414            self.step(rng)?;
415        }
416
417        // Reset counters
418        self.hmc.reset_counters();
419
420        // Sampling phase
421        let dim = self.hmc.position.len();
422        let mut samples = Array2::zeros((n_samples_, dim));
423
424        for i in 0..n_samples_ {
425            let sample = self.step(rng)?;
426            samples.row_mut(i).assign(&sample);
427        }
428
429        Ok(samples)
430    }
431}
432
433// Example implementations
434
435/// Multivariate normal target with gradient
436#[derive(Debug, Clone)]
437pub struct MultivariateNormalHMC {
438    /// Mean vector
439    pub mean: Array1<f64>,
440    /// Precision matrix
441    pub precision: Array2<f64>,
442    /// Log normalizing constant
443    pub log_norm_const: f64,
444}
445
446impl MultivariateNormalHMC {
447    /// Create new multivariate normal target
448    pub fn new(mean: Array1<f64>, covariance: Array2<f64>) -> Result<Self> {
449        checkarray_finite(&mean, "mean")?;
450        checkarray_finite(&covariance, "covariance")?;
451
452        if covariance.nrows() != mean.len() || covariance.ncols() != mean.len() {
453            return Err(StatsError::DimensionMismatch(format!(
454                "covariance shape ({}, {}) must be ({}, {})",
455                covariance.nrows(),
456                covariance.ncols(),
457                mean.len(),
458                mean.len()
459            )));
460        }
461
462        let precision = scirs2_linalg::inv(&covariance.view(), None).map_err(|e| {
463            StatsError::ComputationError(format!("Failed to invert covariance: {}", e))
464        })?;
465
466        let det = scirs2_linalg::det(&covariance.view(), None).map_err(|e| {
467            StatsError::ComputationError(format!("Failed to compute determinant: {}", e))
468        })?;
469
470        if det <= 0.0 {
471            return Err(StatsError::InvalidArgument(
472                "Covariance must be positive definite".to_string(),
473            ));
474        }
475
476        let d = mean.len() as f64;
477        let log_norm_const = -0.5 * (d * (2.0 * std::f64::consts::PI).ln() + det.ln());
478
479        Ok(Self {
480            mean,
481            precision,
482            log_norm_const,
483        })
484    }
485}
486
487impl DifferentiableTarget for MultivariateNormalHMC {
488    fn log_density(&self, x: &Array1<f64>) -> f64 {
489        let diff = x - &self.mean;
490        let quad_form = diff.dot(&self.precision.dot(&diff));
491        self.log_norm_const - 0.5 * quad_form
492    }
493
494    fn gradient(&self, x: &Array1<f64>) -> Array1<f64> {
495        let diff = x - &self.mean;
496        -self.precision.dot(&diff)
497    }
498
499    fn dim(&self) -> usize {
500        self.mean.len()
501    }
502
503    fn log_density_and_gradient(&self, x: &Array1<f64>) -> (f64, Array1<f64>) {
504        let diff = x - &self.mean;
505        let quad_form = diff.dot(&self.precision.dot(&diff));
506        let log_density = self.log_norm_const - 0.5 * quad_form;
507        let gradient = -self.precision.dot(&diff);
508        (log_density, gradient)
509    }
510}
511
/// A differentiable target assembled from user-supplied closures.
///
/// `log_density_fn` evaluates the log density, `gradient_fn` its gradient, and `dim`
/// records the dimensionality of the target.
pub struct CustomDifferentiableTarget<F, G> {
    /// Closure evaluating the log density at a point.
    pub log_density_fn: F,
    /// Closure evaluating the gradient of the log density at a point.
    pub gradient_fn: G,
    /// Number of dimensions of the target.
    pub dim: usize,
}
521
522impl<F, G> CustomDifferentiableTarget<F, G> {
523    /// Create new custom target
524    pub fn new(dim: usize, log_density_fn: F, gradientfn: G) -> Result<Self> {
525        check_positive(dim, "dim")?;
526        Ok(Self {
527            log_density_fn,
528            gradient_fn: gradientfn,
529            dim,
530        })
531    }
532}
533
534impl<F, G> DifferentiableTarget for CustomDifferentiableTarget<F, G>
535where
536    F: Fn(&Array1<f64>) -> f64 + Send + Sync,
537    G: Fn(&Array1<f64>) -> Array1<f64> + Send + Sync,
538{
539    fn log_density(&self, x: &Array1<f64>) -> f64 {
540        (self.log_density_fn)(x)
541    }
542
543    fn gradient(&self, x: &Array1<f64>) -> Array1<f64> {
544        (self.gradient_fn)(x)
545    }
546
547    fn dim(&self) -> usize {
548        self.dim
549    }
550}