Skip to main content

scirs2_stats/mcmc/
enhanced_hamiltonian.rs

//! Enhanced Hamiltonian Monte Carlo (HMC) implementations
//!
//! This module provides:
//! - Adaptive HMC with step-size adaptation and mass-matrix adaptation during warmup
//! - A simplified Riemannian-manifold leapfrog driven by the target's
//!   `fisher_information` hook (not a full RMHMC generalized-leapfrog integrator)
//!
//! No split-HMC or GPU-accelerated code paths are present in this file.
8
9use crate::error::{StatsError, StatsResult};
10use scirs2_core::ndarray::{Array1, Array2, ScalarOperand};
11use scirs2_core::numeric::{Float, NumAssign};
12use scirs2_core::random::{Distribution, Normal};
13use scirs2_core::{simd_ops::SimdUnifiedOps, validation::*};
14use scirs2_core::{Rng, RngExt};
15use std::fmt::Display;
16use std::iter::Sum;
17use std::marker::PhantomData;
18
19/// Enhanced target distribution trait with automatic differentiation support
20pub trait EnhancedDifferentiableTarget<F>: Send + Sync
21where
22    F: Float + Copy + ScalarOperand + NumAssign + Display + Sum + Send + Sync,
23{
24    /// Compute log probability density
25    fn log_density(&self, x: &Array1<F>) -> F;
26
27    /// Compute gradient of log density
28    fn gradient(&self, x: &Array1<F>) -> Array1<F>;
29
30    /// Get dimensionality
31    fn dim(&self) -> usize;
32
33    /// Compute both log density and gradient (for efficiency)
34    fn log_density_and_gradient(&self, x: &Array1<F>) -> (F, Array1<F>) {
35        (self.log_density(x), self.gradient(x))
36    }
37
38    /// Compute Hessian matrix (optional, for Riemannian HMC)
39    fn hessian(x: &Array1<F>) -> Option<Array2<F>> {
40        None
41    }
42
43    /// Compute Fisher information metric (optional, for Riemannian HMC)
44    fn fisher_information(x: &Array1<F>) -> Option<Array2<F>> {
45        None
46    }
47}
48
49/// Enhanced HMC configuration
50#[derive(Debug, Clone)]
51pub struct EnhancedHMCConfig {
52    /// Initial step size
53    pub initial_stepsize: f64,
54    /// Number of leapfrog steps
55    pub num_leapfrog_steps: usize,
56    /// Mass matrix adaptation strategy
57    pub mass_adaptation: MassAdaptationStrategy,
58    /// Step size adaptation strategy
59    pub stepsize_adaptation: StepSizeAdaptationStrategy,
60    /// Whether to use parallel leapfrog integration
61    pub parallel_leapfrog: bool,
62    /// Whether to use SIMD optimizations
63    pub use_simd: bool,
64    /// Target acceptance rate
65    pub target_accept_rate: f64,
66    /// Number of adaptation steps
67    pub adaptation_steps: usize,
68    /// Whether to use Riemannian manifold
69    pub riemannian: bool,
70}
71
72impl Default for EnhancedHMCConfig {
73    fn default() -> Self {
74        Self {
75            initial_stepsize: 0.01,
76            num_leapfrog_steps: 10,
77            mass_adaptation: MassAdaptationStrategy::Identity,
78            stepsize_adaptation: StepSizeAdaptationStrategy::DualAveraging,
79            parallel_leapfrog: true,
80            use_simd: true,
81            target_accept_rate: 0.8,
82            adaptation_steps: 1000,
83            riemannian: false,
84        }
85    }
86}
87
/// How the sampler adapts its mass matrix during the adaptation phase.
///
/// The enum is fieldless, so it is `Copy` (it can be read out of a config
/// without cloning) and `Eq + Hash` (it can be used as a map/set key).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MassAdaptationStrategy {
    /// Keep the identity mass matrix (standard HMC).
    Identity,
    /// Adapt only the diagonal, estimated from the per-coordinate sample
    /// variance of recent draws.
    Diagonal,
    /// Adapt a dense mass matrix estimated from the sample covariance of
    /// recent draws.
    Full,
    /// Choose between `Full` (low-dimensional problems) and `Diagonal`
    /// (high-dimensional problems) based on the problem size.
    Automatic,
}
100
/// How the sampler adapts its leapfrog step size during the adaptation phase.
///
/// The enum is fieldless, so it is `Copy` and `Eq + Hash` like
/// `MassAdaptationStrategy`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StepSizeAdaptationStrategy {
    /// No step-size adaptation.
    Fixed,
    /// Dual-averaging adaptation (Hoffman & Gelman, 2014).
    DualAveraging,
    /// Adaptive step size with warmup.
    ///
    /// NOTE(review): no dedicated code path for this variant is visible in
    /// this module; confirm how the sampler treats it.
    Warmup,
    /// Nesterov's accelerated adaptation.
    ///
    /// NOTE(review): no dedicated code path for this variant is visible in
    /// this module; confirm how the sampler treats it.
    Nesterov,
}
113
114/// Enhanced HMC sampler with advanced features
115pub struct EnhancedHamiltonianMonteCarlo<T, F> {
116    /// Target distribution
117    pub target: T,
118    /// Current position
119    pub position: Array1<F>,
120    /// Current log density
121    pub current_log_density: F,
122    /// Configuration
123    pub config: EnhancedHMCConfig,
124    /// Mass matrix
125    pub mass_matrix: Array2<F>,
126    /// Inverse mass matrix
127    pub mass_inv: Array2<F>,
128    /// Step size
129    pub stepsize: F,
130    /// Adaptation state
131    pub adaptation_state: AdaptationState<F>,
132    /// Statistics
133    pub stats: HMCStatistics,
134    _phantom: PhantomData<F>,
135}
136
137/// Adaptation state for HMC
138#[derive(Debug, Clone)]
139pub struct AdaptationState<F> {
140    /// Current adaptation iteration
141    pub iteration: usize,
142    /// Step size adaptation state
143    pub stepsize_state: DualAveragingState,
144    /// Mass matrix adaptation state
145    pub mass_state: MassAdaptationState<F>,
146    /// Sample buffer for adaptation
147    pub sample_buffer: Vec<Array1<F>>,
148    /// Buffer size
149    pub buffersize: usize,
150}
151
/// State of the dual-averaging step-size adaptation
/// (Hoffman & Gelman, 2014, "The No-U-Turn Sampler", Algorithm 5).
///
/// `m` below denotes the 1-based adaptation iteration.
#[derive(Debug, Clone)]
pub struct DualAveragingState {
    /// Weighted running average of the log step size, `log ε̄`.
    pub log_step_avg: f64,
    /// Running average `H̄` of `(target_accept - observed acceptance
    /// probability)`.
    pub h_avg: f64,
    /// Target average acceptance probability `δ`.
    pub target_accept: f64,
    /// Shrinkage strength `γ`. It divides the accumulated `H̄` term, so larger
    /// values give smaller step-size corrections.
    pub gamma: f64,
    /// Iteration offset `t₀`. It enters the `H̄` update as `1 / (m + t₀)` and
    /// damps the influence of the first few iterations.
    pub t0: f64,
    /// Exponent `κ` of the averaging weight `m^(-κ)` used to update
    /// `log_step_avg`.
    pub kappa: f64,
}
168
169/// Mass adaptation state
170#[derive(Debug, Clone)]
171pub struct MassAdaptationState<F> {
172    /// Running mean
173    pub running_mean: Array1<F>,
174    /// Running covariance
175    pub running_cov: Array2<F>,
176    /// Number of samples seen
177    pub n_samples_: usize,
178}
179
/// Diagnostics collected while the sampler runs.
#[derive(Clone, Debug, Default)]
pub struct HMCStatistics {
    /// Total proposals made.
    pub n_proposals: usize,
    /// Proposals accepted by the Metropolis test.
    pub n_acceptances: usize,
    /// Mean step size recorded over proposals (left at 0.0 unless the sampler
    /// updates it).
    pub avg_stepsize: f64,
    /// Mean number of leapfrog steps recorded over proposals (left at 0.0
    /// unless the sampler updates it).
    pub avg_leapfrog_steps: f64,
    /// Recent changes in the Hamiltonian, `H(proposal) - H(start)`, one entry
    /// per proposal; older entries are trimmed by the sampler.
    pub energy_errors: Vec<f64>,
}
194
195impl<T, F> EnhancedHamiltonianMonteCarlo<T, F>
196where
197    T: EnhancedDifferentiableTarget<F>,
198    F: Float
199        + Copy
200        + Send
201        + Sync
202        + SimdUnifiedOps
203        + ScalarOperand
204        + NumAssign
205        + Display
206        + Sum
207        + 'static,
208{
209    /// Create new enhanced HMC sampler
210    pub fn new(target: T, initial: Array1<F>, config: EnhancedHMCConfig) -> StatsResult<Self> {
211        checkarray_finite(&initial, "initial")?;
212
213        if initial.len() != target.dim() {
214            return Err(StatsError::DimensionMismatch(format!(
215                "Initial position dimension ({}) must match target dimension ({})",
216                initial.len(),
217                target.dim()
218            )));
219        }
220
221        let dim = initial.len();
222        let mass_matrix = Array2::eye(dim);
223        let mass_inv = Array2::eye(dim);
224        let current_log_density = target.log_density(&initial);
225        let stepsize = F::from(config.initial_stepsize).expect("Failed to convert to float");
226
227        let adaptation_state = AdaptationState {
228            iteration: 0,
229            stepsize_state: DualAveragingState {
230                log_step_avg: config.initial_stepsize.ln(),
231                h_avg: 0.0,
232                target_accept: config.target_accept_rate,
233                gamma: 0.05,
234                t0: 10.0,
235                kappa: 0.75,
236            },
237            mass_state: MassAdaptationState {
238                running_mean: Array1::zeros(dim),
239                running_cov: Array2::zeros((dim, dim)),
240                n_samples_: 0,
241            },
242            sample_buffer: Vec::new(),
243            buffersize: 100,
244        };
245
246        Ok(Self {
247            target,
248            position: initial,
249            current_log_density,
250            config,
251            mass_matrix,
252            mass_inv,
253            stepsize,
254            adaptation_state,
255            stats: HMCStatistics::default(),
256            _phantom: PhantomData,
257        })
258    }
259
260    /// Perform one enhanced HMC step
261    pub fn step<R: Rng + ?Sized>(&mut self, rng: &mut R) -> StatsResult<Array1<F>> {
262        // Sample momentum
263        let momentum = self.sample_momentum(rng)?;
264
265        // Store initial state
266        let initial_position = self.position.clone();
267        let initial_momentum = momentum.clone();
268        let initial_log_density = self.current_log_density;
269
270        // Perform enhanced leapfrog integration
271        let (final_position, final_momentum) = if self.config.riemannian {
272            self.riemannian_leapfrog(initial_position.clone(), momentum)?
273        } else if self.config.parallel_leapfrog {
274            self.parallel_leapfrog(initial_position.clone(), momentum)?
275        } else {
276            self.standard_leapfrog(initial_position.clone(), momentum)?
277        };
278
279        // Compute Hamiltonian
280        let initial_hamiltonian = -initial_log_density + self.kinetic_energy(&initial_momentum);
281        let final_log_density = self.target.log_density(&final_position);
282        let final_hamiltonian = -final_log_density + self.kinetic_energy(&final_momentum);
283
284        // Metropolis acceptance
285        let log_alpha = -(final_hamiltonian - initial_hamiltonian);
286        let alpha = log_alpha.exp().min(F::one());
287        let u: f64 = rng.random();
288
289        self.stats.n_proposals += 1;
290
291        let accepted = u < alpha.to_f64().expect("Operation failed");
292        if accepted {
293            self.position = final_position;
294            self.current_log_density = final_log_density;
295            self.stats.n_acceptances += 1;
296        }
297
298        // Update adaptation state
299        if self.adaptation_state.iteration < self.config.adaptation_steps {
300            self.update_adaptation(alpha.to_f64().expect("Operation failed"))?;
301        }
302
303        // Update statistics
304        self.stats.energy_errors.push(
305            (final_hamiltonian - initial_hamiltonian)
306                .to_f64()
307                .expect("Operation failed"),
308        );
309        if self.stats.energy_errors.len() > 1000 {
310            self.stats.energy_errors.drain(0..500); // Keep recent errors
311        }
312
313        self.adaptation_state.iteration += 1;
314
315        Ok(self.position.clone())
316    }
317
318    /// Enhanced leapfrog integration with SIMD optimizations
319    fn standard_leapfrog(
320        &self,
321        mut position: Array1<F>,
322        mut momentum: Array1<F>,
323    ) -> StatsResult<(Array1<F>, Array1<F>)> {
324        // Initial half step for momentum
325        let gradient = self.target.gradient(&position);
326        if self.config.use_simd && position.len() >= 4 {
327            let scaled_gradient = gradient.mapv(|g| {
328                g * self.stepsize * F::from(0.5).expect("Failed to convert constant to float")
329            });
330            momentum = F::simd_add(&momentum.view(), &scaled_gradient.view());
331        } else {
332            momentum = momentum
333                + gradient.mapv(|g| {
334                    g * self.stepsize * F::from(0.5).expect("Failed to convert constant to float")
335                });
336        }
337
338        // Alternating full steps
339        for _ in 0..self.config.num_leapfrog_steps {
340            // Full step for position
341            let momentum_update = self.mass_inv.dot(&momentum);
342            if self.config.use_simd && position.len() >= 4 {
343                let scaled_momentum = momentum_update.mapv(|m| m * self.stepsize);
344                position = F::simd_add(&position.view(), &scaled_momentum.view());
345            } else {
346                position = position + momentum_update.mapv(|m| m * self.stepsize);
347            }
348
349            // Full step for momentum (except last iteration)
350            if self.config.num_leapfrog_steps > 1 {
351                let gradient = self.target.gradient(&position);
352                if self.config.use_simd && position.len() >= 4 {
353                    let scaled_gradient = gradient.mapv(|g| g * self.stepsize);
354                    momentum = F::simd_add(&momentum.view(), &scaled_gradient.view());
355                } else {
356                    momentum = momentum + gradient.mapv(|g| g * self.stepsize);
357                }
358            }
359        }
360
361        // Final half step for momentum
362        let gradient = self.target.gradient(&position);
363        if self.config.use_simd && position.len() >= 4 {
364            let scaled_gradient = gradient.mapv(|g| {
365                g * self.stepsize * F::from(0.5).expect("Failed to convert constant to float")
366            });
367            momentum = F::simd_add(&momentum.view(), &scaled_gradient.view());
368        } else {
369            momentum = momentum
370                + gradient.mapv(|g| {
371                    g * self.stepsize * F::from(0.5).expect("Failed to convert constant to float")
372                });
373        }
374
375        // Negate momentum for reversibility
376        momentum = momentum.mapv(|m| -m);
377
378        Ok((position, momentum))
379    }
380
381    /// Parallel leapfrog integration for large problems
382    fn parallel_leapfrog(
383        &self,
384        position: Array1<F>,
385        momentum: Array1<F>,
386    ) -> StatsResult<(Array1<F>, Array1<F>)> {
387        // For now, use standard leapfrog
388        // Full parallel implementation would require chunking operations
389        self.standard_leapfrog(position, momentum)
390    }
391
392    /// Riemannian manifold leapfrog integration
393    fn riemannian_leapfrog(
394        &self,
395        mut position: Array1<F>,
396        mut momentum: Array1<F>,
397    ) -> StatsResult<(Array1<F>, Array1<F>)> {
398        // Simplified Riemannian leapfrog
399        // Full implementation would use metric tensor and Christoffel symbols
400
401        for _ in 0..self.config.num_leapfrog_steps {
402            // Update momentum using gradient and metric
403            let gradient = self.target.gradient(&position);
404            let metric =
405                T::fisher_information(&position).unwrap_or_else(|| Array2::eye(position.len()));
406
407            let metric_inv = scirs2_linalg::inv(&metric.view(), None)
408                .unwrap_or_else(|_| Array2::eye(position.len()));
409
410            momentum = momentum
411                + gradient.mapv(|g| {
412                    g * self.stepsize * F::from(0.5).expect("Failed to convert constant to float")
413                });
414
415            // Update position using metric
416            let velocity = metric_inv.dot(&momentum);
417            position = position + velocity.mapv(|v| v * self.stepsize);
418
419            // Final momentum update
420            let gradient = self.target.gradient(&position);
421            momentum = momentum
422                + gradient.mapv(|g| {
423                    g * self.stepsize * F::from(0.5).expect("Failed to convert constant to float")
424                });
425        }
426
427        Ok((position, momentum))
428    }
429
430    /// Sample momentum from multivariate normal
431    fn sample_momentum<R: Rng + ?Sized>(&self, rng: &mut R) -> StatsResult<Array1<F>> {
432        let dim = self.position.len();
433        let normal = Normal::new(0.0, 1.0).map_err(|e| {
434            StatsError::ComputationError(format!("Failed to create normal distribution: {}", e))
435        })?;
436
437        // Sample from standard normal
438        let z: Vec<f64> = (0..dim).map(|_| normal.sample(rng)).collect();
439        let z_array = Array1::from_vec(
440            z.into_iter()
441                .map(|x| F::from(x).expect("Failed to convert to float"))
442                .collect(),
443        );
444
445        // Transform using Cholesky decomposition of mass matrix
446        // For simplicity, assume diagonal mass matrix
447        let mut momentum = Array1::zeros(dim);
448        for i in 0..dim {
449            momentum[i] = z_array[i] * self.mass_matrix[[i, i]].sqrt();
450        }
451
452        Ok(momentum)
453    }
454
455    /// Compute kinetic energy
456    fn kinetic_energy(&self, momentum: &Array1<F>) -> F {
457        let mut energy = F::zero();
458        for i in 0..momentum.len() {
459            energy += momentum[i] * momentum[i] * self.mass_inv[[i, i]];
460        }
461        energy * F::from(0.5).expect("Failed to convert constant to float")
462    }
463
464    /// Update adaptation parameters
465    fn update_adaptation(&mut self, alpha: f64) -> StatsResult<()> {
466        // Update step size using dual averaging
467        self.update_stepsize_adaptation(alpha);
468
469        // Update mass matrix
470        self.update_mass_adaptation()?;
471
472        Ok(())
473    }
474
475    /// Update step size adaptation
476    fn update_stepsize_adaptation(&mut self, alpha: f64) {
477        let state = &mut self.adaptation_state.stepsize_state;
478        let m = self.adaptation_state.iteration as f64 + 1.0;
479
480        // Update H statistic
481        state.h_avg = (1.0 - 1.0 / (m + state.t0)) * state.h_avg
482            + (state.target_accept - alpha) / (m + state.t0);
483
484        // Update log step size
485        let log_step = state.log_step_avg - state.h_avg / (state.gamma * m.powf(state.kappa));
486
487        // Update average
488        let weight = m.powf(-state.kappa);
489        state.log_step_avg = (1.0 - weight) * state.log_step_avg + weight * log_step;
490
491        // Update step size
492        self.stepsize = F::from(log_step.exp()).expect("Operation failed");
493    }
494
495    /// Update mass matrix adaptation
496    fn update_mass_adaptation(&mut self) -> StatsResult<()> {
497        let state = &mut self.adaptation_state.mass_state;
498
499        // Add current position to buffer
500        self.adaptation_state
501            .sample_buffer
502            .push(self.position.clone());
503        if self.adaptation_state.sample_buffer.len() > self.adaptation_state.buffersize {
504            self.adaptation_state.sample_buffer.drain(0..1);
505        }
506
507        // Update running statistics
508        state.n_samples_ += 1;
509        let n = state.n_samples_ as f64;
510
511        // Update running mean
512        let delta = &self.position - &state.running_mean;
513        state.running_mean = &state.running_mean
514            + &delta.mapv(|d| d / F::from(n).expect("Failed to convert to float"));
515
516        // Update mass matrix based on strategy
517        match self.config.mass_adaptation {
518            MassAdaptationStrategy::Identity => {
519                // Keep identity mass matrix
520            }
521            MassAdaptationStrategy::Diagonal => {
522                // Update diagonal mass matrix using sample variance
523                if self.adaptation_state.sample_buffer.len() > 10 {
524                    let variance = self.compute_sample_variance()?;
525                    for i in 0..self.mass_matrix.nrows() {
526                        self.mass_matrix[[i, i]] = variance[i];
527                        self.mass_inv[[i, i]] = F::one() / variance[i];
528                    }
529                }
530            }
531            MassAdaptationStrategy::Full => {
532                // Update full mass matrix using sample covariance
533                if self.adaptation_state.sample_buffer.len() > 20 {
534                    let covariance = self.compute_sample_covariance()?;
535                    self.mass_matrix = covariance.clone();
536                    self.mass_inv = scirs2_linalg::inv(&covariance.view(), None)
537                        .unwrap_or_else(|_| Array2::eye(self.position.len()));
538                }
539            }
540            MassAdaptationStrategy::Automatic => {
541                // Choose strategy based on problem size
542                if self.position.len() <= 50 {
543                    // Use full adaptation for small problems
544                    if self.adaptation_state.sample_buffer.len() > 20 {
545                        let covariance = self.compute_sample_covariance()?;
546                        self.mass_matrix = covariance.clone();
547                        self.mass_inv = scirs2_linalg::inv(&covariance.view(), None)
548                            .unwrap_or_else(|_| Array2::eye(self.position.len()));
549                    }
550                } else {
551                    // Use diagonal adaptation for large problems
552                    if self.adaptation_state.sample_buffer.len() > 10 {
553                        let variance = self.compute_sample_variance()?;
554                        for i in 0..self.mass_matrix.nrows() {
555                            self.mass_matrix[[i, i]] = variance[i];
556                            self.mass_inv[[i, i]] = F::one() / variance[i];
557                        }
558                    }
559                }
560            }
561        }
562
563        Ok(())
564    }
565
566    /// Compute sample variance from buffer
567    fn compute_sample_variance(&self) -> StatsResult<Array1<F>> {
568        let buffer = &self.adaptation_state.sample_buffer;
569        if buffer.is_empty() {
570            return Ok(Array1::ones(self.position.len()));
571        }
572
573        let n = buffer.len();
574        let mean = buffer
575            .iter()
576            .fold(Array1::zeros(self.position.len()), |acc, x| acc + x)
577            / F::from(n).expect("Failed to convert to float");
578
579        let variance = buffer
580            .iter()
581            .map(|x| (x - &mean).mapv(|d| d * d))
582            .fold(Array1::zeros(self.position.len()), |acc, x| acc + x)
583            / F::from(n.saturating_sub(1).max(1)).expect("Operation failed");
584
585        Ok(
586            variance
587                .mapv(|v: F| v.max(F::from(1e-6).expect("Failed to convert constant to float"))),
588        ) // Ensure positive variance
589    }
590
591    /// Compute sample covariance from buffer
592    fn compute_sample_covariance(&self) -> StatsResult<Array2<F>> {
593        let buffer = &self.adaptation_state.sample_buffer;
594        if buffer.is_empty() {
595            return Ok(Array2::eye(self.position.len()));
596        }
597
598        let n = buffer.len();
599        let dim = self.position.len();
600        let mean = buffer.iter().fold(Array1::zeros(dim), |acc, x| acc + x)
601            / F::from(n).expect("Failed to convert to float");
602
603        let mut covariance = Array2::zeros((dim, dim));
604        for sample in buffer {
605            let centered = sample - &mean;
606            for i in 0..dim {
607                for j in 0..dim {
608                    covariance[[i, j]] += centered[i] * centered[j];
609                }
610            }
611        }
612
613        covariance /= F::from(n.saturating_sub(1).max(1)).expect("Operation failed");
614
615        // Add small regularization to diagonal
616        for i in 0..dim {
617            covariance[[i, i]] += F::from(1e-6).expect("Failed to convert constant to float");
618        }
619
620        Ok(covariance)
621    }
622
623    /// Get acceptance rate
624    pub fn acceptance_rate(&self) -> f64 {
625        if self.stats.n_proposals == 0 {
626            0.0
627        } else {
628            self.stats.n_acceptances as f64 / self.stats.n_proposals as f64
629        }
630    }
631
632    /// Sample multiple states with adaptation
633    pub fn sample_adaptive<R: Rng + ?Sized>(
634        &mut self,
635        n_samples_: usize,
636        rng: &mut R,
637    ) -> StatsResult<Array2<F>> {
638        let dim = self.position.len();
639        let mut samples = Array2::zeros((n_samples_, dim));
640
641        for i in 0..n_samples_ {
642            let sample = self.step(rng)?;
643            samples.row_mut(i).assign(&sample);
644        }
645
646        Ok(samples)
647    }
648}
649
650/// Convenience function for enhanced HMC sampling
651#[allow(dead_code)]
652pub fn enhanced_hmc_sample<T, F, R>(
653    target: T,
654    initial: Array1<F>,
655    n_samples_: usize,
656    config: Option<EnhancedHMCConfig>,
657    rng: &mut R,
658) -> StatsResult<Array2<F>>
659where
660    T: EnhancedDifferentiableTarget<F>,
661    F: Float
662        + Copy
663        + Send
664        + Sync
665        + SimdUnifiedOps
666        + ScalarOperand
667        + NumAssign
668        + Display
669        + Sum
670        + 'static,
671    R: Rng + ?Sized,
672{
673    let config = config.unwrap_or_default();
674    let mut sampler = EnhancedHamiltonianMonteCarlo::new(target, initial, config)?;
675    sampler.sample_adaptive(n_samples_, rng)
676}