1use crate::error::{StatsError, StatsResult as Result};
7use scirs2_core::ndarray::{Array1, Array2};
8use scirs2_core::random::{Distribution, Normal};
9use scirs2_core::validation::*;
10use scirs2_core::{Rng, RngExt};
11use std::fmt::Debug;
12
13pub trait DifferentiableTarget: Send + Sync {
15 fn log_density(&self, x: &Array1<f64>) -> f64;
17
18 fn gradient(&self, x: &Array1<f64>) -> Array1<f64>;
20
21 fn dim(&self) -> usize;
23
24 fn log_density_and_gradient(&self, x: &Array1<f64>) -> (f64, Array1<f64>) {
26 (self.log_density(x), self.gradient(x))
27 }
28}
29
30pub struct HamiltonianMonteCarlo<T: DifferentiableTarget> {
32 pub target: T,
34 pub position: Array1<f64>,
36 pub current_log_density: f64,
38 pub stepsize: f64,
40 pub n_steps: usize,
42 pub mass_matrix: Array2<f64>,
44 pub mass_inv: Array2<f64>,
46 pub n_accepted: usize,
48 pub n_proposed: usize,
50}
51
52impl<T: DifferentiableTarget> HamiltonianMonteCarlo<T> {
53 pub fn new(target: T, initial: Array1<f64>, stepsize: f64, nsteps: usize) -> Result<Self> {
55 checkarray_finite(&initial, "initial")?;
56 check_positive(stepsize, "stepsize")?;
57 check_positive(nsteps, "nsteps")?;
58
59 if initial.len() != target.dim() {
60 return Err(StatsError::DimensionMismatch(format!(
61 "initial dimension ({}) must match target dimension ({})",
62 initial.len(),
63 target.dim()
64 )));
65 }
66
67 let dim = initial.len();
68 let mass_matrix = Array2::eye(dim);
69 let mass_inv = Array2::eye(dim);
70 let current_log_density = target.log_density(&initial);
71
72 Ok(Self {
73 target,
74 position: initial,
75 current_log_density,
76 stepsize,
77 n_steps: nsteps,
78 mass_matrix,
79 mass_inv,
80 n_accepted: 0,
81 n_proposed: 0,
82 })
83 }
84
85 pub fn with_mass_matrix(mut self, massmatrix: Array2<f64>) -> Result<Self> {
87 checkarray_finite(&massmatrix, "massmatrix")?;
88
89 if massmatrix.nrows() != self.position.len() || massmatrix.ncols() != self.position.len() {
90 return Err(StatsError::DimensionMismatch(format!(
91 "massmatrix shape ({}, {}) must be ({}, {})",
92 massmatrix.nrows(),
93 massmatrix.ncols(),
94 self.position.len(),
95 self.position.len()
96 )));
97 }
98
99 let mass_inv = scirs2_linalg::inv(&massmatrix.view(), None).map_err(|e| {
101 StatsError::ComputationError(format!("Failed to invert mass matrix: {}", e))
102 })?;
103
104 self.mass_matrix = massmatrix;
105 self.mass_inv = mass_inv;
106 Ok(self)
107 }
108
109 pub fn step<R: Rng + ?Sized>(&mut self, rng: &mut R) -> Result<Array1<f64>> {
111 let _dim = self.position.len();
112
113 let momentum = self.sample_momentum(rng)?;
115
116 let initial_position = self.position.clone();
118 let initial_momentum = momentum.clone();
119 let initial_log_density = self.current_log_density;
120
121 let (final_position, final_momentum) = self.leapfrog(initial_position.clone(), momentum)?;
123
124 let initial_hamiltonian =
126 -initial_log_density + 0.5 * self.kinetic_energy(&initial_momentum);
127 let final_log_density = self.target.log_density(&final_position);
128 let final_hamiltonian = -final_log_density + 0.5 * self.kinetic_energy(&final_momentum);
129
130 let log_alpha = -(final_hamiltonian - initial_hamiltonian);
132 let u: f64 = rng.random();
133
134 self.n_proposed += 1;
135
136 if u.ln() < log_alpha {
137 self.position = final_position;
139 self.current_log_density = final_log_density;
140 self.n_accepted += 1;
141 }
142 Ok(self.position.clone())
145 }
146
147 fn sample_momentum<R: Rng + ?Sized>(&self, rng: &mut R) -> Result<Array1<f64>> {
149 let dim = self.position.len();
150 let normal = Normal::new(0.0, 1.0).map_err(|e| {
151 StatsError::ComputationError(format!("Failed to create normal distribution: {}", e))
152 })?;
153
154 let z = Array1::from_shape_fn(dim, |_| normal.sample(rng));
156
157 let mut momentum = Array1::zeros(dim);
160 for i in 0..dim {
161 momentum[i] = z[i] * self.mass_matrix[[i, i]].sqrt();
162 }
163
164 Ok(momentum)
165 }
166
167 fn kinetic_energy(&self, momentum: &Array1<f64>) -> f64 {
169 let mut energy = 0.0;
171 for i in 0..momentum.len() {
172 energy += momentum[i] * momentum[i] * self.mass_inv[[i, i]];
173 }
174 0.5 * energy
175 }
176
177 fn leapfrog(
179 &self,
180 mut position: Array1<f64>,
181 mut momentum: Array1<f64>,
182 ) -> Result<(Array1<f64>, Array1<f64>)> {
183 let gradient = self.target.gradient(&position);
185 momentum = momentum + 0.5 * self.stepsize * gradient;
186
187 for _ in 0..self.n_steps {
189 let momentum_update = self.mass_inv.dot(&momentum);
191 position = position + self.stepsize * momentum_update;
192
193 if self.n_steps > 1 {
195 let gradient = self.target.gradient(&position);
196 momentum = momentum + self.stepsize * gradient;
197 }
198 }
199
200 let gradient = self.target.gradient(&position);
202 momentum = momentum + 0.5 * self.stepsize * gradient;
203
204 momentum = -momentum;
206
207 Ok((position, momentum))
208 }
209
210 pub fn sample<R: Rng + ?Sized>(
212 &mut self,
213 n_samples_: usize,
214 rng: &mut R,
215 ) -> Result<Array2<f64>> {
216 let dim = self.position.len();
217 let mut samples = Array2::zeros((n_samples_, dim));
218
219 for i in 0..n_samples_ {
220 let sample = self.step(rng)?;
221 samples.row_mut(i).assign(&sample);
222 }
223
224 Ok(samples)
225 }
226
227 pub fn sample_with_burnin<R: Rng + ?Sized>(
229 &mut self,
230 n_samples_: usize,
231 burnin: usize,
232 rng: &mut R,
233 ) -> Result<Array2<f64>> {
234 check_positive(burnin, "burnin")?;
235
236 for _ in 0..burnin {
238 self.step(rng)?;
239 }
240
241 self.reset_counters();
243
244 self.sample(n_samples_, rng)
246 }
247
248 pub fn acceptance_rate(&self) -> f64 {
250 if self.n_proposed == 0 {
251 0.0
252 } else {
253 self.n_accepted as f64 / self.n_proposed as f64
254 }
255 }
256
257 pub fn reset_counters(&mut self) {
259 self.n_accepted = 0;
260 self.n_proposed = 0;
261 }
262}
263
264pub struct NoUTurnSampler<T: DifferentiableTarget> {
266 pub hmc: HamiltonianMonteCarlo<T>,
268 pub max_tree_depth: usize,
270 pub target_accept_prob: f64,
272 pub stepsize_adaptation: DualAveragingAdaptation,
274}
275
/// State of a dual-averaging style step-size adaptation scheme.
///
/// NOTE(review): this differs from Hoffman & Gelman (2014), Algorithm 5. It has no fixed
/// shrinkage point `μ`: `log_step` is pulled toward the running average `log_step_avg`.
/// It also scales `h_avg` by `1 / (gamma * m^kappa)` instead of `sqrt(m) / gamma`.
/// Confirm that this is intentional.
#[derive(Debug, Clone)]
pub struct DualAveragingAdaptation {
    /// Target acceptance probability.
    pub target: f64,
    /// Scale parameter dividing the accumulated statistic.
    pub gamma: f64,
    /// Offset added to the iteration count in the `h_avg` recursion.
    pub t0: f64,
    /// Exponent for the iteration-dependent weights.
    pub kappa: f64,
    /// Number of `update` calls made so far.
    pub iteration: usize,
    /// Weighted running average of the log step size.
    pub log_step_avg: f64,
    /// Running average of `target - alpha`.
    pub h_avg: f64,
}

impl DualAveragingAdaptation {
    /// Creates the adaptation state with acceptance target `target` and initial log step
    /// size `initial_logstep`. Uses `gamma = 0.05`, `t0 = 10` and `kappa = 0.75`.
    pub fn new(target: f64, initial_logstep: f64) -> Self {
        Self {
            log_step_avg: initial_logstep,
            h_avg: 0.0,
            iteration: 0,
            target,
            gamma: 0.05,
            t0: 10.0,
            kappa: 0.75,
        }
    }

    /// Feeds one observed acceptance probability `alpha` and returns the new step size
    /// `exp(log_step)`. Also advances `iteration`, `h_avg` and `log_step_avg`.
    pub fn update(&mut self, alpha: f64) -> f64 {
        self.iteration += 1;
        let t = self.iteration as f64;

        // h_avg ← (1 - 1/(t+t0)) h_avg + (target - alpha)/(t+t0)
        let offset = t + self.t0;
        let decay = 1.0 - 1.0 / offset;
        self.h_avg = decay * self.h_avg + (self.target - alpha) / offset;

        let scale = self.gamma * t.powf(self.kappa);
        let log_step = self.log_step_avg - self.h_avg / scale;

        // Blend the new iterate into the running average with weight t^(-kappa).
        let w = t.powf(-self.kappa);
        self.log_step_avg = (1.0 - w) * self.log_step_avg + w * log_step;

        log_step.exp()
    }
}
328
329impl<T: DifferentiableTarget> NoUTurnSampler<T> {
330 pub fn new(target: T, initial: Array1<f64>, initial_stepsize: f64) -> Result<Self> {
332 let hmc = HamiltonianMonteCarlo::new(target, initial, initial_stepsize, 1)?;
333 let stepsize_adaptation = DualAveragingAdaptation::new(0.8, initial_stepsize.ln());
334
335 Ok(Self {
336 hmc,
337 max_tree_depth: 10,
338 target_accept_prob: 0.8,
339 stepsize_adaptation,
340 })
341 }
342
343 pub fn step<R: Rng + ?Sized>(&mut self, rng: &mut R) -> Result<Array1<f64>> {
345 let momentum = self.hmc.sample_momentum(rng)?;
347
348 let (new_position, alpha) =
350 self.build_tree(self.hmc.position.clone(), momentum, 0.0, 1, rng)?;
351
352 let new_stepsize = self.stepsize_adaptation.update(alpha);
354 self.hmc.stepsize = new_stepsize;
355
356 if !new_position
358 .iter()
359 .zip(self.hmc.position.iter())
360 .all(|(a, b)| (a - b).abs() < f64::EPSILON)
361 {
362 self.hmc.position = new_position;
363 self.hmc.current_log_density = self.hmc.target.log_density(&self.hmc.position);
364 self.hmc.n_accepted += 1;
365 }
366
367 self.hmc.n_proposed += 1;
368 Ok(self.hmc.position.clone())
369 }
370
371 fn build_tree<R: Rng + ?Sized>(
373 &self,
374 position: Array1<f64>,
375 momentum: Array1<f64>,
376 log_u: f64,
377 depth: usize,
378 rng: &mut R,
379 ) -> Result<(Array1<f64>, f64)> {
380 if depth >= self.max_tree_depth {
381 return Ok((position, 0.0));
383 }
384
385 let (new_position, new_momentum) = self.hmc.leapfrog(position.clone(), momentum.clone())?;
387
388 let new_log_density = self.hmc.target.log_density(&new_position);
390 let new_hamiltonian = -new_log_density + 0.5 * self.hmc.kinetic_energy(&new_momentum);
391
392 let current_hamiltonian =
394 -self.hmc.current_log_density + 0.5 * self.hmc.kinetic_energy(&momentum);
395 let log_alpha = -(new_hamiltonian - current_hamiltonian);
396 let alpha = log_alpha.exp().min(1.0);
397
398 if log_u <= log_alpha {
399 Ok((new_position, alpha))
400 } else {
401 Ok((position, alpha))
402 }
403 }
404
405 pub fn sample_adaptive<R: Rng + ?Sized>(
407 &mut self,
408 n_samples_: usize,
409 n_adapt: usize,
410 rng: &mut R,
411 ) -> Result<Array2<f64>> {
412 for _ in 0..n_adapt {
414 self.step(rng)?;
415 }
416
417 self.hmc.reset_counters();
419
420 let dim = self.hmc.position.len();
422 let mut samples = Array2::zeros((n_samples_, dim));
423
424 for i in 0..n_samples_ {
425 let sample = self.step(rng)?;
426 samples.row_mut(i).assign(&sample);
427 }
428
429 Ok(samples)
430 }
431}
432
433#[derive(Debug, Clone)]
437pub struct MultivariateNormalHMC {
438 pub mean: Array1<f64>,
440 pub precision: Array2<f64>,
442 pub log_norm_const: f64,
444}
445
446impl MultivariateNormalHMC {
447 pub fn new(mean: Array1<f64>, covariance: Array2<f64>) -> Result<Self> {
449 checkarray_finite(&mean, "mean")?;
450 checkarray_finite(&covariance, "covariance")?;
451
452 if covariance.nrows() != mean.len() || covariance.ncols() != mean.len() {
453 return Err(StatsError::DimensionMismatch(format!(
454 "covariance shape ({}, {}) must be ({}, {})",
455 covariance.nrows(),
456 covariance.ncols(),
457 mean.len(),
458 mean.len()
459 )));
460 }
461
462 let precision = scirs2_linalg::inv(&covariance.view(), None).map_err(|e| {
463 StatsError::ComputationError(format!("Failed to invert covariance: {}", e))
464 })?;
465
466 let det = scirs2_linalg::det(&covariance.view(), None).map_err(|e| {
467 StatsError::ComputationError(format!("Failed to compute determinant: {}", e))
468 })?;
469
470 if det <= 0.0 {
471 return Err(StatsError::InvalidArgument(
472 "Covariance must be positive definite".to_string(),
473 ));
474 }
475
476 let d = mean.len() as f64;
477 let log_norm_const = -0.5 * (d * (2.0 * std::f64::consts::PI).ln() + det.ln());
478
479 Ok(Self {
480 mean,
481 precision,
482 log_norm_const,
483 })
484 }
485}
486
487impl DifferentiableTarget for MultivariateNormalHMC {
488 fn log_density(&self, x: &Array1<f64>) -> f64 {
489 let diff = x - &self.mean;
490 let quad_form = diff.dot(&self.precision.dot(&diff));
491 self.log_norm_const - 0.5 * quad_form
492 }
493
494 fn gradient(&self, x: &Array1<f64>) -> Array1<f64> {
495 let diff = x - &self.mean;
496 -self.precision.dot(&diff)
497 }
498
499 fn dim(&self) -> usize {
500 self.mean.len()
501 }
502
503 fn log_density_and_gradient(&self, x: &Array1<f64>) -> (f64, Array1<f64>) {
504 let diff = x - &self.mean;
505 let quad_form = diff.dot(&self.precision.dot(&diff));
506 let log_density = self.log_norm_const - 0.5 * quad_form;
507 let gradient = -self.precision.dot(&diff);
508 (log_density, gradient)
509 }
510}
511
512pub struct CustomDifferentiableTarget<F, G> {
514 pub log_density_fn: F,
516 pub gradient_fn: G,
518 pub dim: usize,
520}
521
522impl<F, G> CustomDifferentiableTarget<F, G> {
523 pub fn new(dim: usize, log_density_fn: F, gradientfn: G) -> Result<Self> {
525 check_positive(dim, "dim")?;
526 Ok(Self {
527 log_density_fn,
528 gradient_fn: gradientfn,
529 dim,
530 })
531 }
532}
533
534impl<F, G> DifferentiableTarget for CustomDifferentiableTarget<F, G>
535where
536 F: Fn(&Array1<f64>) -> f64 + Send + Sync,
537 G: Fn(&Array1<f64>) -> Array1<f64> + Send + Sync,
538{
539 fn log_density(&self, x: &Array1<f64>) -> f64 {
540 (self.log_density_fn)(x)
541 }
542
543 fn gradient(&self, x: &Array1<f64>) -> Array1<f64> {
544 (self.gradient_fn)(x)
545 }
546
547 fn dim(&self) -> usize {
548 self.dim
549 }
550}