1use crate::error::{StatsError, StatsResult};
10use scirs2_core::ndarray::{Array1, Array2, ScalarOperand};
11use scirs2_core::numeric::{Float, NumAssign};
12use scirs2_core::random::{Distribution, Normal};
13use scirs2_core::{simd_ops::SimdUnifiedOps, validation::*};
14use scirs2_core::{Rng, RngExt};
15use std::fmt::Display;
16use std::iter::Sum;
17use std::marker::PhantomData;
18
19pub trait EnhancedDifferentiableTarget<F>: Send + Sync
21where
22 F: Float + Copy + ScalarOperand + NumAssign + Display + Sum + Send + Sync,
23{
24 fn log_density(&self, x: &Array1<F>) -> F;
26
27 fn gradient(&self, x: &Array1<F>) -> Array1<F>;
29
30 fn dim(&self) -> usize;
32
33 fn log_density_and_gradient(&self, x: &Array1<F>) -> (F, Array1<F>) {
35 (self.log_density(x), self.gradient(x))
36 }
37
38 fn hessian(x: &Array1<F>) -> Option<Array2<F>> {
40 None
41 }
42
43 fn fisher_information(x: &Array1<F>) -> Option<Array2<F>> {
45 None
46 }
47}
48
49#[derive(Debug, Clone)]
51pub struct EnhancedHMCConfig {
52 pub initial_stepsize: f64,
54 pub num_leapfrog_steps: usize,
56 pub mass_adaptation: MassAdaptationStrategy,
58 pub stepsize_adaptation: StepSizeAdaptationStrategy,
60 pub parallel_leapfrog: bool,
62 pub use_simd: bool,
64 pub target_accept_rate: f64,
66 pub adaptation_steps: usize,
68 pub riemannian: bool,
70}
71
72impl Default for EnhancedHMCConfig {
73 fn default() -> Self {
74 Self {
75 initial_stepsize: 0.01,
76 num_leapfrog_steps: 10,
77 mass_adaptation: MassAdaptationStrategy::Identity,
78 stepsize_adaptation: StepSizeAdaptationStrategy::DualAveraging,
79 parallel_leapfrog: true,
80 use_simd: true,
81 target_accept_rate: 0.8,
82 adaptation_steps: 1000,
83 riemannian: false,
84 }
85 }
86}
87
/// Strategy used to adapt the mass matrix during the adaptation phase.
///
/// The enum is fieldless, so it is cheap to copy and supports full equality.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MassAdaptationStrategy {
    /// Keep the mass matrix as constructed (no adaptation).
    Identity,
    /// Adapt only the diagonal from per-coordinate sample variances.
    Diagonal,
    /// Adapt a dense matrix from the sample covariance.
    Full,
    /// Let the sampler choose between the dense and diagonal variants.
    Automatic,
}
100
/// Strategy used to adapt the integrator step size.
///
/// The enum is fieldless, so it is cheap to copy and supports full equality.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StepSizeAdaptationStrategy {
    /// Keep the step size fixed.
    Fixed,
    /// Dual-averaging adaptation.
    DualAveraging,
    /// Warm-up based adaptation.
    Warmup,
    /// Nesterov-style adaptation.
    Nesterov,
}
113
114pub struct EnhancedHamiltonianMonteCarlo<T, F> {
116 pub target: T,
118 pub position: Array1<F>,
120 pub current_log_density: F,
122 pub config: EnhancedHMCConfig,
124 pub mass_matrix: Array2<F>,
126 pub mass_inv: Array2<F>,
128 pub stepsize: F,
130 pub adaptation_state: AdaptationState<F>,
132 pub stats: HMCStatistics,
134 _phantom: PhantomData<F>,
135}
136
137#[derive(Debug, Clone)]
139pub struct AdaptationState<F> {
140 pub iteration: usize,
142 pub stepsize_state: DualAveragingState,
144 pub mass_state: MassAdaptationState<F>,
146 pub sample_buffer: Vec<Array1<F>>,
148 pub buffersize: usize,
150}
151
/// State of the dual-averaging step-size adaptation.
///
/// `PartialEq` is derived so that adaptation state can be compared directly
/// (for example in regression tests).
#[derive(Debug, Clone, PartialEq)]
pub struct DualAveragingState {
    /// Running weighted average of the log step size.
    pub log_step_avg: f64,
    /// Running average of `target_accept - observed acceptance`.
    pub h_avg: f64,
    /// Acceptance rate the adaptation aims for.
    pub target_accept: f64,
    /// Scale applied to `h_avg` when proposing a new log step size.
    pub gamma: f64,
    /// Offset added to the iteration count in the `h_avg` averaging weight.
    pub t0: f64,
    /// Exponent of the iteration count in the log-step averaging weight.
    pub kappa: f64,
}
168
169#[derive(Debug, Clone)]
171pub struct MassAdaptationState<F> {
172 pub running_mean: Array1<F>,
174 pub running_cov: Array2<F>,
176 pub n_samples_: usize,
178}
179
/// Counters and diagnostics accumulated while sampling.
///
/// All fields start at zero / empty via `Default`. `PartialEq` is derived so
/// that statistics snapshots can be compared directly.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct HMCStatistics {
    /// Number of proposals made.
    pub n_proposals: usize,
    /// Number of proposals accepted.
    pub n_acceptances: usize,
    /// Average integrator step size.
    pub avg_stepsize: f64,
    /// Average number of leapfrog steps per proposal.
    pub avg_leapfrog_steps: f64,
    /// Recent energy errors (final minus initial Hamiltonian) of proposals.
    pub energy_errors: Vec<f64>,
}
194
195impl<T, F> EnhancedHamiltonianMonteCarlo<T, F>
196where
197 T: EnhancedDifferentiableTarget<F>,
198 F: Float
199 + Copy
200 + Send
201 + Sync
202 + SimdUnifiedOps
203 + ScalarOperand
204 + NumAssign
205 + Display
206 + Sum
207 + 'static,
208{
209 pub fn new(target: T, initial: Array1<F>, config: EnhancedHMCConfig) -> StatsResult<Self> {
211 checkarray_finite(&initial, "initial")?;
212
213 if initial.len() != target.dim() {
214 return Err(StatsError::DimensionMismatch(format!(
215 "Initial position dimension ({}) must match target dimension ({})",
216 initial.len(),
217 target.dim()
218 )));
219 }
220
221 let dim = initial.len();
222 let mass_matrix = Array2::eye(dim);
223 let mass_inv = Array2::eye(dim);
224 let current_log_density = target.log_density(&initial);
225 let stepsize = F::from(config.initial_stepsize).expect("Failed to convert to float");
226
227 let adaptation_state = AdaptationState {
228 iteration: 0,
229 stepsize_state: DualAveragingState {
230 log_step_avg: config.initial_stepsize.ln(),
231 h_avg: 0.0,
232 target_accept: config.target_accept_rate,
233 gamma: 0.05,
234 t0: 10.0,
235 kappa: 0.75,
236 },
237 mass_state: MassAdaptationState {
238 running_mean: Array1::zeros(dim),
239 running_cov: Array2::zeros((dim, dim)),
240 n_samples_: 0,
241 },
242 sample_buffer: Vec::new(),
243 buffersize: 100,
244 };
245
246 Ok(Self {
247 target,
248 position: initial,
249 current_log_density,
250 config,
251 mass_matrix,
252 mass_inv,
253 stepsize,
254 adaptation_state,
255 stats: HMCStatistics::default(),
256 _phantom: PhantomData,
257 })
258 }
259
260 pub fn step<R: Rng + ?Sized>(&mut self, rng: &mut R) -> StatsResult<Array1<F>> {
262 let momentum = self.sample_momentum(rng)?;
264
265 let initial_position = self.position.clone();
267 let initial_momentum = momentum.clone();
268 let initial_log_density = self.current_log_density;
269
270 let (final_position, final_momentum) = if self.config.riemannian {
272 self.riemannian_leapfrog(initial_position.clone(), momentum)?
273 } else if self.config.parallel_leapfrog {
274 self.parallel_leapfrog(initial_position.clone(), momentum)?
275 } else {
276 self.standard_leapfrog(initial_position.clone(), momentum)?
277 };
278
279 let initial_hamiltonian = -initial_log_density + self.kinetic_energy(&initial_momentum);
281 let final_log_density = self.target.log_density(&final_position);
282 let final_hamiltonian = -final_log_density + self.kinetic_energy(&final_momentum);
283
284 let log_alpha = -(final_hamiltonian - initial_hamiltonian);
286 let alpha = log_alpha.exp().min(F::one());
287 let u: f64 = rng.random();
288
289 self.stats.n_proposals += 1;
290
291 let accepted = u < alpha.to_f64().expect("Operation failed");
292 if accepted {
293 self.position = final_position;
294 self.current_log_density = final_log_density;
295 self.stats.n_acceptances += 1;
296 }
297
298 if self.adaptation_state.iteration < self.config.adaptation_steps {
300 self.update_adaptation(alpha.to_f64().expect("Operation failed"))?;
301 }
302
303 self.stats.energy_errors.push(
305 (final_hamiltonian - initial_hamiltonian)
306 .to_f64()
307 .expect("Operation failed"),
308 );
309 if self.stats.energy_errors.len() > 1000 {
310 self.stats.energy_errors.drain(0..500); }
312
313 self.adaptation_state.iteration += 1;
314
315 Ok(self.position.clone())
316 }
317
318 fn standard_leapfrog(
320 &self,
321 mut position: Array1<F>,
322 mut momentum: Array1<F>,
323 ) -> StatsResult<(Array1<F>, Array1<F>)> {
324 let gradient = self.target.gradient(&position);
326 if self.config.use_simd && position.len() >= 4 {
327 let scaled_gradient = gradient.mapv(|g| {
328 g * self.stepsize * F::from(0.5).expect("Failed to convert constant to float")
329 });
330 momentum = F::simd_add(&momentum.view(), &scaled_gradient.view());
331 } else {
332 momentum = momentum
333 + gradient.mapv(|g| {
334 g * self.stepsize * F::from(0.5).expect("Failed to convert constant to float")
335 });
336 }
337
338 for _ in 0..self.config.num_leapfrog_steps {
340 let momentum_update = self.mass_inv.dot(&momentum);
342 if self.config.use_simd && position.len() >= 4 {
343 let scaled_momentum = momentum_update.mapv(|m| m * self.stepsize);
344 position = F::simd_add(&position.view(), &scaled_momentum.view());
345 } else {
346 position = position + momentum_update.mapv(|m| m * self.stepsize);
347 }
348
349 if self.config.num_leapfrog_steps > 1 {
351 let gradient = self.target.gradient(&position);
352 if self.config.use_simd && position.len() >= 4 {
353 let scaled_gradient = gradient.mapv(|g| g * self.stepsize);
354 momentum = F::simd_add(&momentum.view(), &scaled_gradient.view());
355 } else {
356 momentum = momentum + gradient.mapv(|g| g * self.stepsize);
357 }
358 }
359 }
360
361 let gradient = self.target.gradient(&position);
363 if self.config.use_simd && position.len() >= 4 {
364 let scaled_gradient = gradient.mapv(|g| {
365 g * self.stepsize * F::from(0.5).expect("Failed to convert constant to float")
366 });
367 momentum = F::simd_add(&momentum.view(), &scaled_gradient.view());
368 } else {
369 momentum = momentum
370 + gradient.mapv(|g| {
371 g * self.stepsize * F::from(0.5).expect("Failed to convert constant to float")
372 });
373 }
374
375 momentum = momentum.mapv(|m| -m);
377
378 Ok((position, momentum))
379 }
380
381 fn parallel_leapfrog(
383 &self,
384 position: Array1<F>,
385 momentum: Array1<F>,
386 ) -> StatsResult<(Array1<F>, Array1<F>)> {
387 self.standard_leapfrog(position, momentum)
390 }
391
392 fn riemannian_leapfrog(
394 &self,
395 mut position: Array1<F>,
396 mut momentum: Array1<F>,
397 ) -> StatsResult<(Array1<F>, Array1<F>)> {
398 for _ in 0..self.config.num_leapfrog_steps {
402 let gradient = self.target.gradient(&position);
404 let metric =
405 T::fisher_information(&position).unwrap_or_else(|| Array2::eye(position.len()));
406
407 let metric_inv = scirs2_linalg::inv(&metric.view(), None)
408 .unwrap_or_else(|_| Array2::eye(position.len()));
409
410 momentum = momentum
411 + gradient.mapv(|g| {
412 g * self.stepsize * F::from(0.5).expect("Failed to convert constant to float")
413 });
414
415 let velocity = metric_inv.dot(&momentum);
417 position = position + velocity.mapv(|v| v * self.stepsize);
418
419 let gradient = self.target.gradient(&position);
421 momentum = momentum
422 + gradient.mapv(|g| {
423 g * self.stepsize * F::from(0.5).expect("Failed to convert constant to float")
424 });
425 }
426
427 Ok((position, momentum))
428 }
429
430 fn sample_momentum<R: Rng + ?Sized>(&self, rng: &mut R) -> StatsResult<Array1<F>> {
432 let dim = self.position.len();
433 let normal = Normal::new(0.0, 1.0).map_err(|e| {
434 StatsError::ComputationError(format!("Failed to create normal distribution: {}", e))
435 })?;
436
437 let z: Vec<f64> = (0..dim).map(|_| normal.sample(rng)).collect();
439 let z_array = Array1::from_vec(
440 z.into_iter()
441 .map(|x| F::from(x).expect("Failed to convert to float"))
442 .collect(),
443 );
444
445 let mut momentum = Array1::zeros(dim);
448 for i in 0..dim {
449 momentum[i] = z_array[i] * self.mass_matrix[[i, i]].sqrt();
450 }
451
452 Ok(momentum)
453 }
454
455 fn kinetic_energy(&self, momentum: &Array1<F>) -> F {
457 let mut energy = F::zero();
458 for i in 0..momentum.len() {
459 energy += momentum[i] * momentum[i] * self.mass_inv[[i, i]];
460 }
461 energy * F::from(0.5).expect("Failed to convert constant to float")
462 }
463
464 fn update_adaptation(&mut self, alpha: f64) -> StatsResult<()> {
466 self.update_stepsize_adaptation(alpha);
468
469 self.update_mass_adaptation()?;
471
472 Ok(())
473 }
474
475 fn update_stepsize_adaptation(&mut self, alpha: f64) {
477 let state = &mut self.adaptation_state.stepsize_state;
478 let m = self.adaptation_state.iteration as f64 + 1.0;
479
480 state.h_avg = (1.0 - 1.0 / (m + state.t0)) * state.h_avg
482 + (state.target_accept - alpha) / (m + state.t0);
483
484 let log_step = state.log_step_avg - state.h_avg / (state.gamma * m.powf(state.kappa));
486
487 let weight = m.powf(-state.kappa);
489 state.log_step_avg = (1.0 - weight) * state.log_step_avg + weight * log_step;
490
491 self.stepsize = F::from(log_step.exp()).expect("Operation failed");
493 }
494
495 fn update_mass_adaptation(&mut self) -> StatsResult<()> {
497 let state = &mut self.adaptation_state.mass_state;
498
499 self.adaptation_state
501 .sample_buffer
502 .push(self.position.clone());
503 if self.adaptation_state.sample_buffer.len() > self.adaptation_state.buffersize {
504 self.adaptation_state.sample_buffer.drain(0..1);
505 }
506
507 state.n_samples_ += 1;
509 let n = state.n_samples_ as f64;
510
511 let delta = &self.position - &state.running_mean;
513 state.running_mean = &state.running_mean
514 + &delta.mapv(|d| d / F::from(n).expect("Failed to convert to float"));
515
516 match self.config.mass_adaptation {
518 MassAdaptationStrategy::Identity => {
519 }
521 MassAdaptationStrategy::Diagonal => {
522 if self.adaptation_state.sample_buffer.len() > 10 {
524 let variance = self.compute_sample_variance()?;
525 for i in 0..self.mass_matrix.nrows() {
526 self.mass_matrix[[i, i]] = variance[i];
527 self.mass_inv[[i, i]] = F::one() / variance[i];
528 }
529 }
530 }
531 MassAdaptationStrategy::Full => {
532 if self.adaptation_state.sample_buffer.len() > 20 {
534 let covariance = self.compute_sample_covariance()?;
535 self.mass_matrix = covariance.clone();
536 self.mass_inv = scirs2_linalg::inv(&covariance.view(), None)
537 .unwrap_or_else(|_| Array2::eye(self.position.len()));
538 }
539 }
540 MassAdaptationStrategy::Automatic => {
541 if self.position.len() <= 50 {
543 if self.adaptation_state.sample_buffer.len() > 20 {
545 let covariance = self.compute_sample_covariance()?;
546 self.mass_matrix = covariance.clone();
547 self.mass_inv = scirs2_linalg::inv(&covariance.view(), None)
548 .unwrap_or_else(|_| Array2::eye(self.position.len()));
549 }
550 } else {
551 if self.adaptation_state.sample_buffer.len() > 10 {
553 let variance = self.compute_sample_variance()?;
554 for i in 0..self.mass_matrix.nrows() {
555 self.mass_matrix[[i, i]] = variance[i];
556 self.mass_inv[[i, i]] = F::one() / variance[i];
557 }
558 }
559 }
560 }
561 }
562
563 Ok(())
564 }
565
566 fn compute_sample_variance(&self) -> StatsResult<Array1<F>> {
568 let buffer = &self.adaptation_state.sample_buffer;
569 if buffer.is_empty() {
570 return Ok(Array1::ones(self.position.len()));
571 }
572
573 let n = buffer.len();
574 let mean = buffer
575 .iter()
576 .fold(Array1::zeros(self.position.len()), |acc, x| acc + x)
577 / F::from(n).expect("Failed to convert to float");
578
579 let variance = buffer
580 .iter()
581 .map(|x| (x - &mean).mapv(|d| d * d))
582 .fold(Array1::zeros(self.position.len()), |acc, x| acc + x)
583 / F::from(n.saturating_sub(1).max(1)).expect("Operation failed");
584
585 Ok(
586 variance
587 .mapv(|v: F| v.max(F::from(1e-6).expect("Failed to convert constant to float"))),
588 ) }
590
591 fn compute_sample_covariance(&self) -> StatsResult<Array2<F>> {
593 let buffer = &self.adaptation_state.sample_buffer;
594 if buffer.is_empty() {
595 return Ok(Array2::eye(self.position.len()));
596 }
597
598 let n = buffer.len();
599 let dim = self.position.len();
600 let mean = buffer.iter().fold(Array1::zeros(dim), |acc, x| acc + x)
601 / F::from(n).expect("Failed to convert to float");
602
603 let mut covariance = Array2::zeros((dim, dim));
604 for sample in buffer {
605 let centered = sample - &mean;
606 for i in 0..dim {
607 for j in 0..dim {
608 covariance[[i, j]] += centered[i] * centered[j];
609 }
610 }
611 }
612
613 covariance /= F::from(n.saturating_sub(1).max(1)).expect("Operation failed");
614
615 for i in 0..dim {
617 covariance[[i, i]] += F::from(1e-6).expect("Failed to convert constant to float");
618 }
619
620 Ok(covariance)
621 }
622
623 pub fn acceptance_rate(&self) -> f64 {
625 if self.stats.n_proposals == 0 {
626 0.0
627 } else {
628 self.stats.n_acceptances as f64 / self.stats.n_proposals as f64
629 }
630 }
631
632 pub fn sample_adaptive<R: Rng + ?Sized>(
634 &mut self,
635 n_samples_: usize,
636 rng: &mut R,
637 ) -> StatsResult<Array2<F>> {
638 let dim = self.position.len();
639 let mut samples = Array2::zeros((n_samples_, dim));
640
641 for i in 0..n_samples_ {
642 let sample = self.step(rng)?;
643 samples.row_mut(i).assign(&sample);
644 }
645
646 Ok(samples)
647 }
648}
649
650#[allow(dead_code)]
652pub fn enhanced_hmc_sample<T, F, R>(
653 target: T,
654 initial: Array1<F>,
655 n_samples_: usize,
656 config: Option<EnhancedHMCConfig>,
657 rng: &mut R,
658) -> StatsResult<Array2<F>>
659where
660 T: EnhancedDifferentiableTarget<F>,
661 F: Float
662 + Copy
663 + Send
664 + Sync
665 + SimdUnifiedOps
666 + ScalarOperand
667 + NumAssign
668 + Display
669 + Sum
670 + 'static,
671 R: Rng + ?Sized,
672{
673 let config = config.unwrap_or_default();
674 let mut sampler = EnhancedHamiltonianMonteCarlo::new(target, initial, config)?;
675 sampler.sample_adaptive(n_samples_, rng)
676}