1use scirs2_core::ndarray::{Array1, Array2, Array3, ArrayView2};
8use scirs2_core::random::{thread_rng, Distribution, RandNormal, Rng, SeedableRng};
9use sklears_core::{
10 error::{Result as SklResult, SklearsError},
11 types::Float,
12};
13
/// Configuration for a (simplified) No-U-Turn Sampler over Gaussian-mixture
/// parameters. Built fluently: `NUTSSampler::new().n_samples(..).build()`.
#[derive(Debug, Clone)]
pub struct NUTSSampler {
    /// Number of mixture components in the model.
    pub(crate) n_components: usize,
    /// Number of posterior draws to keep (after warmup).
    pub(crate) n_samples: usize,
    /// Number of warmup iterations discarded from the output.
    pub(crate) n_warmup: usize,
    /// Initial leapfrog step size.
    pub(crate) step_size: f64,
    /// Acceptance probability targeted by step-size adaptation.
    pub(crate) target_accept_rate: f64,
    /// Maximum number of trajectory-tree doublings per iteration.
    pub(crate) max_tree_depth: usize,
    /// Whether to tune the step size during warmup.
    pub(crate) adapt_step_size: bool,
    /// Whether to adapt the mass matrix.
    /// NOTE(review): not read by the visible sampling code — the mass matrix
    /// stays fixed at the identity in `sample`; confirm intended.
    pub(crate) adapt_mass_matrix: bool,
    /// Optional RNG seed for reproducible chains.
    pub(crate) random_state: Option<u64>,
}
62
/// Posterior draws and sampler diagnostics produced by `NUTSSampler::sample`.
#[derive(Debug, Clone)]
pub struct NUTSResult {
    /// Mixture-weight draws, shape (n_samples, n_components).
    pub(crate) weights_samples: Array2<f64>,
    /// Component-mean draws, shape (n_samples, n_components, n_features).
    pub(crate) means_samples: Array3<f64>,
    /// Per-component covariance draws; each entry has shape
    /// (n_samples, n_features, n_features).
    pub(crate) covariances_samples: Vec<Array3<f64>>,
    /// Log-posterior value of each stored draw.
    pub(crate) log_posterior_samples: Array1<f64>,
    /// Fraction of accepted proposals over all iterations (warmup included).
    pub(crate) acceptance_rate: f64,
    /// Leapfrog step size after warmup adaptation.
    pub(crate) step_size_final: f64,
    /// Number of iterations flagged divergent.
    pub(crate) n_divergent: usize,
    /// Trajectory-tree depth reached at each stored draw.
    pub(crate) tree_depth_samples: Array1<usize>,
}
75
/// A single point on a simulated Hamiltonian trajectory.
#[derive(Debug, Clone)]
struct TreeNode {
    /// Parameter vector θ at this point.
    position: Array1<f64>,
    /// Momentum vector p at this point.
    momentum: Array1<f64>,
    /// Log-posterior evaluated at `position`.
    log_posterior: f64,
    /// Gradient of the log-posterior at `position`.
    gradient: Array1<f64>,
}
84
/// Accumulated state of a (sub)tree during NUTS trajectory doubling.
#[derive(Debug)]
struct TreeState {
    /// Backward-most endpoint of the trajectory.
    left_node: TreeNode,
    /// Forward-most endpoint of the trajectory.
    right_node: TreeNode,
    /// Candidate sample selected from this tree.
    proposal: TreeNode,
    /// Number of states contributing to proposal selection.
    n_proposals: usize,
    /// Sum of momenta over the tree.
    /// NOTE(review): accumulated but not used by the visible U-turn check.
    sum_momentum: Array1<f64>,
    /// Sum of squared momentum norms over the tree.
    sum_momentum_squared: f64,
    /// False once an invalid/divergent state is detected.
    valid: bool,
}
96
97impl NUTSSampler {
98 pub fn new() -> Self {
100 Self {
101 n_components: 1,
102 n_samples: 1000,
103 n_warmup: 500,
104 step_size: 0.1,
105 target_accept_rate: 0.8,
106 max_tree_depth: 10,
107 adapt_step_size: true,
108 adapt_mass_matrix: true,
109 random_state: None,
110 }
111 }
112
113 pub fn builder() -> Self {
115 Self::new()
116 }
117
118 pub fn n_components(mut self, n_components: usize) -> Self {
120 self.n_components = n_components;
121 self
122 }
123
124 pub fn n_samples(mut self, n_samples: usize) -> Self {
126 self.n_samples = n_samples;
127 self
128 }
129
130 pub fn n_warmup(mut self, n_warmup: usize) -> Self {
132 self.n_warmup = n_warmup;
133 self
134 }
135
136 pub fn step_size(mut self, step_size: f64) -> Self {
138 self.step_size = step_size;
139 self
140 }
141
142 pub fn target_accept_rate(mut self, rate: f64) -> Self {
144 self.target_accept_rate = rate;
145 self
146 }
147
148 pub fn max_tree_depth(mut self, depth: usize) -> Self {
150 self.max_tree_depth = depth;
151 self
152 }
153
154 pub fn adapt_step_size(mut self, adapt: bool) -> Self {
156 self.adapt_step_size = adapt;
157 self
158 }
159
160 pub fn adapt_mass_matrix(mut self, adapt: bool) -> Self {
162 self.adapt_mass_matrix = adapt;
163 self
164 }
165
166 pub fn random_state(mut self, random_state: u64) -> Self {
168 self.random_state = Some(random_state);
169 self
170 }
171
172 pub fn build(self) -> Self {
174 self
175 }
176
177 #[allow(non_snake_case)]
179 pub fn sample(&self, X: &ArrayView2<Float>) -> SklResult<NUTSResult> {
180 let X = X.to_owned();
181 let (n_observations, n_features) = X.dim();
182
183 if n_observations < 2 {
184 return Err(SklearsError::InvalidInput(
185 "Number of observations must be at least 2".to_string(),
186 ));
187 }
188
189 if self.n_components == 0 {
190 return Err(SklearsError::InvalidInput(
191 "Number of components must be positive".to_string(),
192 ));
193 }
194
195 let mut rng = if let Some(seed) = self.random_state {
196 scirs2_core::random::rngs::StdRng::seed_from_u64(seed)
197 } else {
198 scirs2_core::random::rngs::StdRng::from_rng(&mut thread_rng())
199 };
200
201 let n_params = self.calculate_n_parameters(n_features);
203 let mut current_position = self.initialize_parameters(n_features, &mut rng)?;
204
205 let mass_matrix = Array1::ones(n_params);
207
208 let mut weights_samples = Array2::zeros((self.n_samples, self.n_components));
210 let mut means_samples = Array3::zeros((self.n_samples, self.n_components, n_features));
211 let mut covariances_samples = Vec::new();
212 for _ in 0..self.n_components {
213 covariances_samples.push(Array3::zeros((self.n_samples, n_features, n_features)));
214 }
215 let mut log_posterior_samples = Array1::zeros(self.n_samples);
216 let mut tree_depth_samples = Array1::zeros(self.n_samples);
217
218 let mut step_size = self.step_size;
220 let mut n_accepted = 0;
221 let mut n_divergent = 0;
222
223 for sample_idx in 0..(self.n_samples + self.n_warmup) {
225 let is_warmup = sample_idx < self.n_warmup;
226
227 let momentum = self.sample_momentum(&mass_matrix, &mut rng)?;
229
230 let (log_posterior, gradient) =
232 self.compute_log_posterior_and_gradient(&X, ¤t_position, n_features)?;
233
234 let tree_result = self.build_tree(
236 ¤t_position,
237 &momentum,
238 log_posterior,
239 &gradient,
240 &X,
241 n_features,
242 step_size,
243 &mass_matrix,
244 &mut rng,
245 )?;
246
247 let accept_prob = (tree_result.proposal.log_posterior
249 - (log_posterior + 0.5 * self.kinetic_energy(&momentum, &mass_matrix)))
250 .exp()
251 .min(1.0);
252
253 let accept = rng.gen::<f64>() < accept_prob;
254
255 if accept {
256 current_position = tree_result.proposal.position.clone();
257 n_accepted += 1;
258 }
259
260 if !tree_result.valid {
261 n_divergent += 1;
262 }
263
264 if !is_warmup {
266 let sample_idx_adjusted = sample_idx - self.n_warmup;
267 self.store_sample(
268 sample_idx_adjusted,
269 ¤t_position,
270 n_features,
271 &mut weights_samples,
272 &mut means_samples,
273 &mut covariances_samples,
274 &mut log_posterior_samples,
275 &mut tree_depth_samples,
276 tree_result.tree_depth,
277 )?;
278 }
279
280 if is_warmup && self.adapt_step_size {
282 step_size =
283 self.adapt_step_size_dual_averaging(step_size, accept_prob, sample_idx + 1);
284 }
285 }
286
287 let acceptance_rate = n_accepted as f64 / (self.n_samples + self.n_warmup) as f64;
288
289 Ok(NUTSResult {
290 weights_samples,
291 means_samples,
292 covariances_samples,
293 log_posterior_samples,
294 acceptance_rate,
295 step_size_final: step_size,
296 n_divergent,
297 tree_depth_samples,
298 })
299 }
300
301 fn calculate_n_parameters(&self, n_features: usize) -> usize {
303 (self.n_components - 1)
306 + (self.n_components * n_features)
307 + (self.n_components * n_features)
308 }
309
310 fn initialize_parameters(
312 &self,
313 n_features: usize,
314 rng: &mut scirs2_core::random::rngs::StdRng,
315 ) -> SklResult<Array1<f64>> {
316 let n_params = self.calculate_n_parameters(n_features);
317 let mut params = Array1::zeros(n_params);
318
319 for i in 0..n_params {
321 let normal = RandNormal::new(0.0, 0.1).map_err(|e| {
322 SklearsError::InvalidInput(format!("Normal distribution error: {}", e))
323 })?;
324 params[i] = rng.sample(normal);
325 }
326
327 Ok(params)
328 }
329
330 fn sample_momentum(
332 &self,
333 mass_matrix: &Array1<f64>,
334 rng: &mut scirs2_core::random::rngs::StdRng,
335 ) -> SklResult<Array1<f64>> {
336 let n_params = mass_matrix.len();
337 let mut momentum = Array1::zeros(n_params);
338
339 for i in 0..n_params {
340 let std_dev = mass_matrix[i].sqrt();
341 momentum[i] = RandNormal::new(0.0, std_dev)
342 .map_err(|e| {
343 SklearsError::InvalidInput(format!("Normal distribution error: {}", e))
344 })?
345 .sample(rng);
346 }
347
348 Ok(momentum)
349 }
350
351 fn compute_log_posterior_and_gradient(
353 &self,
354 _X: &Array2<f64>,
355 position: &Array1<f64>,
356 _n_features: usize,
357 ) -> SklResult<(f64, Array1<f64>)> {
358 let log_posterior = -0.5 * position.mapv(|x| x * x).sum(); let gradient = -position.clone(); Ok((log_posterior, gradient))
366 }
367
368 fn kinetic_energy(&self, momentum: &Array1<f64>, mass_matrix: &Array1<f64>) -> f64 {
370 0.5 * (momentum * momentum / mass_matrix).sum()
371 }
372
    /// Iteratively doubles a trajectory tree until a U-turn, an invalid
    /// subtree, or the depth cap is reached.
    ///
    /// Each iteration picks a random direction, extends the matching edge of
    /// the tree by one doubling via `build_subtree`, and adopts the subtree's
    /// proposal with probability proportional to its proposal count.
    ///
    /// NOTE(review): `tree_state.valid` is initialized `true` and never set
    /// `false` here, so `TreeResult::valid` is always `true` and `n_divergent`
    /// in `sample` stays 0 — confirm whether subtree invalidity should
    /// propagate to the returned result.
    fn build_tree(
        &self,
        position: &Array1<f64>,
        momentum: &Array1<f64>,
        log_posterior: f64,
        gradient: &Array1<f64>,
        X: &Array2<f64>,
        n_features: usize,
        step_size: f64,
        mass_matrix: &Array1<f64>,
        rng: &mut scirs2_core::random::rngs::StdRng,
    ) -> SklResult<TreeResult> {
        // Both edges and the proposal start at the current state.
        let initial_node = TreeNode {
            position: position.clone(),
            momentum: momentum.clone(),
            log_posterior,
            gradient: gradient.clone(),
        };

        let mut tree_state = TreeState {
            left_node: initial_node.clone(),
            right_node: initial_node.clone(),
            proposal: initial_node.clone(),
            n_proposals: 1,
            sum_momentum: momentum.clone(),
            sum_momentum_squared: momentum.mapv(|x| x * x).sum(),
            valid: true,
        };

        let mut tree_depth = 0;

        for depth in 0..self.max_tree_depth {
            tree_depth = depth;

            // Randomly extend the trajectory forward or backward in time.
            let direction = if rng.gen::<f64>() < 0.5 { 1.0 } else { -1.0 };

            let subtree = self.build_subtree(
                if direction > 0.0 {
                    &tree_state.right_node
                } else {
                    &tree_state.left_node
                },
                direction,
                depth,
                step_size,
                X,
                n_features,
                mass_matrix,
                rng,
            )?;

            // Stop doubling on an invalid subtree or a detected U-turn.
            if !subtree.valid || self.check_uturn(&tree_state, &subtree) {
                break;
            }

            // Graft the subtree onto the matching edge of the current tree.
            if direction > 0.0 {
                tree_state.right_node = subtree.right_node;
            } else {
                tree_state.left_node = subtree.left_node;
            }

            // Adopt the subtree's proposal with probability proportional to
            // the number of states it contributed.
            let accept_prob =
                subtree.n_proposals as f64 / (tree_state.n_proposals + subtree.n_proposals) as f64;
            if rng.gen::<f64>() < accept_prob {
                tree_state.proposal = subtree.proposal;
            }

            tree_state.n_proposals += subtree.n_proposals;
            tree_state.sum_momentum = &tree_state.sum_momentum + &subtree.sum_momentum;
            tree_state.sum_momentum_squared += subtree.sum_momentum_squared;
        }

        Ok(TreeResult {
            proposal: tree_state.proposal,
            valid: tree_state.valid,
            tree_depth,
        })
    }
459
    /// Recursively builds a balanced subtree of 2^depth leapfrog states.
    ///
    /// At `depth == 0` this is a single leapfrog step in the chosen
    /// direction; otherwise it builds two half-depth subtrees and merges
    /// them, choosing the proposal from either half with probability
    /// proportional to its proposal count.
    fn build_subtree(
        &self,
        node: &TreeNode,
        direction: f64,
        depth: usize,
        step_size: f64,
        X: &Array2<f64>,
        n_features: usize,
        mass_matrix: &Array1<f64>,
        rng: &mut scirs2_core::random::rngs::StdRng,
    ) -> SklResult<TreeState> {
        if depth == 0 {
            // Base case: one leapfrog step; the signed step encodes direction.
            let new_node =
                self.leapfrog_step(node, direction * step_size, X, n_features, mass_matrix)?;

            Ok(TreeState {
                left_node: new_node.clone(),
                right_node: new_node.clone(),
                proposal: new_node.clone(),
                n_proposals: 1,
                sum_momentum: new_node.momentum.clone(),
                sum_momentum_squared: new_node.momentum.mapv(|x| x * x).sum(),
                valid: true,
            })
        } else {
            // First half of the doubling, extending from the entry node.
            let left_subtree = self.build_subtree(
                node,
                direction,
                depth - 1,
                step_size,
                X,
                n_features,
                mass_matrix,
                rng,
            )?;

            // An invalid first half invalidates the whole subtree.
            if !left_subtree.valid {
                return Ok(left_subtree);
            }

            // Second half extends from the outer edge of the first half.
            let right_subtree = self.build_subtree(
                if direction > 0.0 {
                    &left_subtree.right_node
                } else {
                    &left_subtree.left_node
                },
                direction,
                depth - 1,
                step_size,
                X,
                n_features,
                mass_matrix,
                rng,
            )?;

            // Keep the first half's extent/proposal but flag invalidity.
            if !right_subtree.valid {
                return Ok(TreeState {
                    left_node: left_subtree.left_node,
                    right_node: left_subtree.right_node,
                    proposal: left_subtree.proposal,
                    n_proposals: left_subtree.n_proposals,
                    sum_momentum: left_subtree.sum_momentum,
                    sum_momentum_squared: left_subtree.sum_momentum_squared,
                    valid: false,
                });
            }

            // Select the proposal from either half, weighted by counts.
            let total_proposals = left_subtree.n_proposals + right_subtree.n_proposals;
            let proposal =
                if rng.gen::<f64>() < (right_subtree.n_proposals as f64 / total_proposals as f64) {
                    right_subtree.proposal
                } else {
                    left_subtree.proposal
                };

            Ok(TreeState {
                left_node: if direction > 0.0 {
                    left_subtree.left_node
                } else {
                    right_subtree.left_node
                },
                right_node: if direction > 0.0 {
                    right_subtree.right_node
                } else {
                    left_subtree.right_node
                },
                proposal,
                n_proposals: total_proposals,
                sum_momentum: &left_subtree.sum_momentum + &right_subtree.sum_momentum,
                sum_momentum_squared: left_subtree.sum_momentum_squared
                    + right_subtree.sum_momentum_squared,
                valid: true,
            })
        }
    }
559
560 fn leapfrog_step(
562 &self,
563 node: &TreeNode,
564 step_size: f64,
565 X: &Array2<f64>,
566 n_features: usize,
567 mass_matrix: &Array1<f64>,
568 ) -> SklResult<TreeNode> {
569 let momentum_half = &node.momentum + 0.5 * step_size * &node.gradient;
571
572 let new_position = &node.position + step_size * (&momentum_half / mass_matrix);
574
575 let (new_log_posterior, new_gradient) =
577 self.compute_log_posterior_and_gradient(X, &new_position, n_features)?;
578
579 let new_momentum = &momentum_half + 0.5 * step_size * &new_gradient;
581
582 Ok(TreeNode {
583 position: new_position,
584 momentum: new_momentum,
585 log_posterior: new_log_posterior,
586 gradient: new_gradient,
587 })
588 }
589
590 fn check_uturn(&self, tree_state: &TreeState, _subtree: &TreeState) -> bool {
592 let left_momentum = &tree_state.left_node.momentum;
594 let right_momentum = &tree_state.right_node.momentum;
595 let momentum_diff = right_momentum - left_momentum;
596
597 left_momentum.dot(&momentum_diff) < 0.0 || right_momentum.dot(&momentum_diff) < 0.0
599 }
600
601 fn adapt_step_size_dual_averaging(
603 &self,
604 current_step_size: f64,
605 accept_prob: f64,
606 iteration: usize,
607 ) -> f64 {
608 let delta = self.target_accept_rate - accept_prob;
609 let gamma = 0.05;
610 let t0 = 10.0;
611 let kappa = 0.75;
612
613 let eta = 1.0 / (iteration as f64 + t0);
614 let log_step_size = current_step_size.ln() + eta * delta;
615
616 (log_step_size - gamma * (iteration as f64).powf(-kappa) * delta).exp()
617 }
618
619 fn store_sample(
621 &self,
622 sample_idx: usize,
623 position: &Array1<f64>,
624 n_features: usize,
625 weights_samples: &mut Array2<f64>,
626 means_samples: &mut Array3<f64>,
627 covariances_samples: &mut Vec<Array3<f64>>,
628 log_posterior_samples: &mut Array1<f64>,
629 tree_depth_samples: &mut Array1<usize>,
630 tree_depth: usize,
631 ) -> SklResult<()> {
632 let (weights, means, covariances) = self.decode_parameters(position, n_features)?;
635
636 for k in 0..self.n_components {
638 weights_samples[[sample_idx, k]] = weights[k];
639 }
640
641 for k in 0..self.n_components {
643 for j in 0..n_features {
644 means_samples[[sample_idx, k, j]] = means[[k, j]];
645 }
646 }
647
648 for k in 0..self.n_components {
650 for i in 0..n_features {
651 for j in 0..n_features {
652 covariances_samples[k][[sample_idx, i, j]] =
653 if i == j { covariances[k] } else { 0.0 };
654 }
655 }
656 }
657
658 log_posterior_samples[sample_idx] = -0.5 * position.mapv(|x| x * x).sum();
660
661 tree_depth_samples[sample_idx] = tree_depth;
663
664 Ok(())
665 }
666
667 fn decode_parameters(
669 &self,
670 _position: &Array1<f64>,
671 n_features: usize,
672 ) -> SklResult<(Array1<f64>, Array2<f64>, Array1<f64>)> {
673 let weights = Array1::ones(self.n_components) / self.n_components as f64;
675 let means = Array2::zeros((self.n_components, n_features));
676 let covariances = Array1::ones(self.n_components);
677
678 Ok((weights, means, covariances))
682 }
683}
684
/// Outcome of one trajectory-tree construction in `build_tree`.
#[derive(Debug)]
struct TreeResult {
    /// Candidate next state selected from the trajectory.
    proposal: TreeNode,
    /// Whether the tree finished without being marked invalid.
    valid: bool,
    /// Depth reached before the doubling loop terminated.
    tree_depth: usize,
}
692
impl Default for NUTSSampler {
    /// Equivalent to [`NUTSSampler::new`].
    fn default() -> Self {
        Self::new()
    }
}
698
699impl NUTSResult {
    /// Mixture-weight draws, shape (n_samples, n_components).
    pub fn weights_samples(&self) -> &Array2<f64> {
        &self.weights_samples
    }

    /// Component-mean draws, shape (n_samples, n_components, n_features).
    pub fn means_samples(&self) -> &Array3<f64> {
        &self.means_samples
    }

    /// Per-component covariance draws.
    pub fn covariances_samples(&self) -> &[Array3<f64>] {
        &self.covariances_samples
    }

    /// Log-posterior value of each stored draw.
    pub fn log_posterior_samples(&self) -> &Array1<f64> {
        &self.log_posterior_samples
    }

    /// Fraction of accepted proposals over all iterations.
    pub fn acceptance_rate(&self) -> f64 {
        self.acceptance_rate
    }

    /// Leapfrog step size after warmup adaptation.
    pub fn step_size_final(&self) -> f64 {
        self.step_size_final
    }

    /// Number of iterations flagged divergent.
    pub fn n_divergent(&self) -> usize {
        self.n_divergent
    }

    /// Trajectory-tree depth reached at each stored draw.
    pub fn tree_depth_samples(&self) -> &Array1<usize> {
        &self.tree_depth_samples
    }
739
740 pub fn posterior_means(&self) -> SklResult<(Array1<f64>, Array2<f64>, Vec<Array2<f64>>)> {
742 let n_samples = self.weights_samples.shape()[0];
743 let n_components = self.weights_samples.shape()[1];
744 let n_features = self.means_samples.shape()[2];
745
746 let mut mean_weights = Array1::zeros(n_components);
748 for k in 0..n_components {
749 mean_weights[k] = self.weights_samples.column(k).mean().unwrap_or(0.0);
750 }
751
752 let mut mean_means = Array2::zeros((n_components, n_features));
754 for k in 0..n_components {
755 for j in 0..n_features {
756 let values: Vec<f64> = (0..n_samples)
757 .map(|i| self.means_samples[[i, k, j]])
758 .collect();
759 mean_means[[k, j]] = values.iter().sum::<f64>() / n_samples as f64;
760 }
761 }
762
763 let mut mean_covariances = Vec::new();
765 for k in 0..n_components {
766 let mut mean_cov = Array2::zeros((n_features, n_features));
767 for i in 0..n_features {
768 for j in 0..n_features {
769 let values: Vec<f64> = (0..n_samples)
770 .map(|s| self.covariances_samples[k][[s, i, j]])
771 .collect();
772 mean_cov[[i, j]] = values.iter().sum::<f64>() / n_samples as f64;
773 }
774 }
775 mean_covariances.push(mean_cov);
776 }
777
778 Ok((mean_weights, mean_means, mean_covariances))
779 }
780
781 pub fn credible_intervals(&self, alpha: f64) -> SklResult<(Array2<f64>, Array3<f64>)> {
783 let n_components = self.weights_samples.shape()[1];
784 let n_features = self.means_samples.shape()[2];
785
786 let mut weight_intervals = Array2::zeros((n_components, 2));
788 for k in 0..n_components {
789 let mut values: Vec<f64> = self.weights_samples.column(k).to_vec();
790 values.sort_by(|a, b| a.partial_cmp(b).unwrap());
791 let lower_idx = ((alpha / 2.0) * values.len() as f64) as usize;
792 let upper_idx = ((1.0 - alpha / 2.0) * values.len() as f64) as usize;
793 weight_intervals[[k, 0]] = values[lower_idx];
794 weight_intervals[[k, 1]] = values[upper_idx.min(values.len() - 1)];
795 }
796
797 let mut mean_intervals = Array3::zeros((n_components, n_features, 2));
799 for k in 0..n_components {
800 for j in 0..n_features {
801 let mut values: Vec<f64> = (0..self.means_samples.shape()[0])
802 .map(|i| self.means_samples[[i, k, j]])
803 .collect();
804 values.sort_by(|a, b| a.partial_cmp(b).unwrap());
805 let lower_idx = ((alpha / 2.0) * values.len() as f64) as usize;
806 let upper_idx = ((1.0 - alpha / 2.0) * values.len() as f64) as usize;
807 mean_intervals[[k, j, 0]] = values[lower_idx];
808 mean_intervals[[k, j, 1]] = values[upper_idx.min(values.len() - 1)];
809 }
810 }
811
812 Ok((weight_intervals, mean_intervals))
813 }
814}
815
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// Smoke test: sampling runs end-to-end and produces correctly shaped
    /// output arrays with a sane acceptance rate.
    #[test]
    #[allow(non_snake_case)]
    fn test_nuts_sampler_basic() {
        let X = array![[0.0, 0.0], [1.0, 1.0], [10.0, 10.0], [11.0, 11.0]];

        let nuts = NUTSSampler::new()
            .n_components(2)
            .n_samples(10)
            .n_warmup(5)
            .max_tree_depth(3)
            .random_state(42);

        let result = nuts.sample(&X.view()).unwrap();

        assert_eq!(result.weights_samples.shape(), &[10, 2]);
        assert_eq!(result.means_samples.shape(), &[10, 2, 2]);
        assert_eq!(result.covariances_samples.len(), 2);
        assert!(result.acceptance_rate >= 0.0 && result.acceptance_rate <= 1.0);
    }

    /// Every builder setter is reflected in the final configuration.
    #[test]
    fn test_nuts_sampler_builder() {
        let nuts = NUTSSampler::builder()
            .n_components(3)
            .n_samples(100)
            .n_warmup(50)
            .step_size(0.05)
            .target_accept_rate(0.85)
            .max_tree_depth(12)
            .adapt_step_size(true)
            .adapt_mass_matrix(false)
            .random_state(123)
            .build();

        assert_eq!(nuts.n_components, 3);
        assert_eq!(nuts.n_samples, 100);
        assert_eq!(nuts.n_warmup, 50);
        assert_relative_eq!(nuts.step_size, 0.05);
        assert_relative_eq!(nuts.target_accept_rate, 0.85);
        assert_eq!(nuts.max_tree_depth, 12);
        assert!(nuts.adapt_step_size);
        assert!(!nuts.adapt_mass_matrix);
    }

    /// Posterior summaries (means and credible intervals) have the expected
    /// shapes for a single-component, single-feature problem.
    #[test]
    #[allow(non_snake_case)]
    fn test_nuts_result_analysis() {
        let X = array![[0.0], [1.0], [2.0]];

        let nuts = NUTSSampler::new()
            .n_components(1)
            .n_samples(5)
            .n_warmup(2)
            .random_state(42);

        let result = nuts.sample(&X.view()).unwrap();

        let (mean_weights, mean_means, mean_covariances) = result.posterior_means().unwrap();
        assert_eq!(mean_weights.len(), 1);
        assert_eq!(mean_means.shape(), &[1, 1]);
        assert_eq!(mean_covariances.len(), 1);

        let (weight_intervals, mean_intervals) = result.credible_intervals(0.05).unwrap();
        assert_eq!(weight_intervals.shape(), &[1, 2]);
        assert_eq!(mean_intervals.shape(), &[1, 1, 2]);
    }

    /// KE = 0.5 · Σ p²/m = 0.5 · (1 + 4 + 9) = 7 for unit masses.
    #[test]
    fn test_kinetic_energy() {
        let nuts = NUTSSampler::new();
        let momentum = array![1.0, 2.0, 3.0];
        let mass_matrix = array![1.0, 1.0, 1.0];

        let ke = nuts.kinetic_energy(&momentum, &mass_matrix);
        assert_relative_eq!(ke, 7.0);
    }

    /// Adaptation yields a positive step size whether the acceptance
    /// probability is below or above the target.
    #[test]
    fn test_step_size_adaptation() {
        let nuts = NUTSSampler::new().target_accept_rate(0.8);

        let step_size = nuts.adapt_step_size_dual_averaging(0.1, 0.6, 10);
        assert!(step_size > 0.0);

        let step_size2 = nuts.adapt_step_size_dual_averaging(0.1, 0.9, 10);
        assert!(step_size2 > 0.0);
    }
}