sklears_mixture/
nuts.rs

1//! No-U-Turn Sampler (NUTS) for Bayesian Mixture Models
2//!
3//! This module implements the No-U-Turn Sampler (NUTS), an extension of Hamiltonian Monte Carlo
4//! that adaptively determines the number of leapfrog steps. NUTS is particularly effective
5//! for sampling from high-dimensional posterior distributions in Bayesian mixture models.
6
7use scirs2_core::ndarray::{Array1, Array2, Array3, ArrayView2};
8use scirs2_core::random::{thread_rng, Distribution, RandNormal, Rng, SeedableRng};
9use sklears_core::{
10    error::{Result as SklResult, SklearsError},
11    types::Float,
12};
13
/// No-U-Turn Sampler for Bayesian Mixture Models
///
/// NUTS is an adaptive variant of Hamiltonian Monte Carlo that automatically determines
/// the number of leapfrog steps by continuing until the trajectory starts to "turn around"
/// and move back towards its starting point. This makes it highly efficient for exploring
/// complex posterior distributions without requiring manual tuning of the trajectory length.
///
/// # Parameters
///
/// * `n_components` - Number of mixture components
/// * `n_samples` - Number of MCMC samples to draw
/// * `n_warmup` - Number of warmup samples for adaptation
/// * `step_size` - Initial step size for leapfrog integration
/// * `target_accept_rate` - Target acceptance rate for step size adaptation
/// * `max_tree_depth` - Maximum tree depth to prevent infinite loops
/// * `adapt_step_size` - Whether to adapt step size during warmup
/// * `adapt_mass_matrix` - Whether to adapt the mass matrix during warmup
/// * `random_state` - Random state for reproducibility
///
/// # Examples
///
/// ```
/// use sklears_mixture::NUTSSampler;
/// use scirs2_core::ndarray::array;
///
/// let X = array![[0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [10.0, 10.0], [11.0, 11.0], [12.0, 12.0]];
///
/// let nuts = NUTSSampler::new()
///     .n_components(2)
///     .n_samples(1000)
///     .n_warmup(500)
///     .target_accept_rate(0.8)
///     .max_tree_depth(10);
///
/// let result = nuts.sample(&X.view()).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct NUTSSampler {
    /// Number of mixture components.
    pub(crate) n_components: usize,
    /// Number of post-warmup MCMC draws to retain.
    pub(crate) n_samples: usize,
    /// Number of warmup iterations used for adaptation.
    pub(crate) n_warmup: usize,
    /// Initial leapfrog step size.
    pub(crate) step_size: f64,
    /// Target Metropolis acceptance rate for step-size adaptation.
    pub(crate) target_accept_rate: f64,
    /// Maximum number of trajectory doublings per iteration.
    pub(crate) max_tree_depth: usize,
    /// Whether to adapt the step size during warmup.
    pub(crate) adapt_step_size: bool,
    /// Whether to adapt the mass matrix during warmup.
    /// NOTE(review): `sample` currently uses a fixed identity mass matrix and
    /// never reads this flag.
    pub(crate) adapt_mass_matrix: bool,
    /// RNG seed; `None` seeds from the thread-local RNG.
    pub(crate) random_state: Option<u64>,
}
62
/// NUTS sampling result
#[derive(Debug, Clone)]
pub struct NUTSResult {
    /// Mixture-weight draws, shape `(n_samples, n_components)`.
    pub(crate) weights_samples: Array2<f64>,
    /// Component-mean draws, shape `(n_samples, n_components, n_features)`.
    pub(crate) means_samples: Array3<f64>,
    /// Per-component covariance draws; entry `k` has shape
    /// `(n_samples, n_features, n_features)`.
    pub(crate) covariances_samples: Vec<Array3<f64>>,
    /// Log-posterior value recorded for each retained draw.
    pub(crate) log_posterior_samples: Array1<f64>,
    /// Fraction of accepted proposals over all iterations (warmup included).
    pub(crate) acceptance_rate: f64,
    /// Step size in effect after the final adaptation update.
    pub(crate) step_size_final: f64,
    /// Number of divergent (invalid) trajectories encountered.
    pub(crate) n_divergent: usize,
    /// Tree depth reached on each retained draw.
    pub(crate) tree_depth_samples: Array1<usize>,
}
75
/// NUTS tree node for building the trajectory
#[derive(Debug, Clone)]
struct TreeNode {
    /// Position q (unconstrained parameter vector).
    position: Array1<f64>,
    /// Auxiliary momentum p at this point of the trajectory.
    momentum: Array1<f64>,
    /// Log posterior evaluated at `position`.
    log_posterior: f64,
    /// Gradient of the log posterior at `position`.
    gradient: Array1<f64>,
}
84
/// NUTS tree state
#[derive(Debug)]
struct TreeState {
    /// Backward-most (earliest fictitious time) node of the trajectory.
    left_node: TreeNode,
    /// Forward-most (latest fictitious time) node of the trajectory.
    right_node: TreeNode,
    /// Candidate next state, sampled from the points visited by the tree.
    proposal: TreeNode,
    /// Number of states in the tree (used for progressive proposal sampling).
    n_proposals: usize,
    /// Running sum of momenta across the tree.
    sum_momentum: Array1<f64>,
    /// Running sum of squared momentum components across the tree.
    sum_momentum_squared: f64,
    /// False once the trajectory has become invalid (e.g. divergence).
    valid: bool,
}
96
97impl NUTSSampler {
98    /// Create a new NUTSSampler instance
99    pub fn new() -> Self {
100        Self {
101            n_components: 1,
102            n_samples: 1000,
103            n_warmup: 500,
104            step_size: 0.1,
105            target_accept_rate: 0.8,
106            max_tree_depth: 10,
107            adapt_step_size: true,
108            adapt_mass_matrix: true,
109            random_state: None,
110        }
111    }
112
113    /// Create a new NUTSSampler instance using builder pattern (alias for new)
114    pub fn builder() -> Self {
115        Self::new()
116    }
117
118    /// Set the number of components
119    pub fn n_components(mut self, n_components: usize) -> Self {
120        self.n_components = n_components;
121        self
122    }
123
124    /// Set the number of samples to draw
125    pub fn n_samples(mut self, n_samples: usize) -> Self {
126        self.n_samples = n_samples;
127        self
128    }
129
130    /// Set the number of warmup samples
131    pub fn n_warmup(mut self, n_warmup: usize) -> Self {
132        self.n_warmup = n_warmup;
133        self
134    }
135
136    /// Set the initial step size
137    pub fn step_size(mut self, step_size: f64) -> Self {
138        self.step_size = step_size;
139        self
140    }
141
142    /// Set the target acceptance rate
143    pub fn target_accept_rate(mut self, rate: f64) -> Self {
144        self.target_accept_rate = rate;
145        self
146    }
147
148    /// Set the maximum tree depth
149    pub fn max_tree_depth(mut self, depth: usize) -> Self {
150        self.max_tree_depth = depth;
151        self
152    }
153
154    /// Set whether to adapt step size
155    pub fn adapt_step_size(mut self, adapt: bool) -> Self {
156        self.adapt_step_size = adapt;
157        self
158    }
159
160    /// Set whether to adapt mass matrix
161    pub fn adapt_mass_matrix(mut self, adapt: bool) -> Self {
162        self.adapt_mass_matrix = adapt;
163        self
164    }
165
166    /// Set the random state
167    pub fn random_state(mut self, random_state: u64) -> Self {
168        self.random_state = Some(random_state);
169        self
170    }
171
172    /// Build the NUTSSampler (builder pattern completion)
173    pub fn build(self) -> Self {
174        self
175    }
176
177    /// Sample from the posterior distribution of a Gaussian mixture model
178    #[allow(non_snake_case)]
179    pub fn sample(&self, X: &ArrayView2<Float>) -> SklResult<NUTSResult> {
180        let X = X.to_owned();
181        let (n_observations, n_features) = X.dim();
182
183        if n_observations < 2 {
184            return Err(SklearsError::InvalidInput(
185                "Number of observations must be at least 2".to_string(),
186            ));
187        }
188
189        if self.n_components == 0 {
190            return Err(SklearsError::InvalidInput(
191                "Number of components must be positive".to_string(),
192            ));
193        }
194
195        let mut rng = if let Some(seed) = self.random_state {
196            scirs2_core::random::rngs::StdRng::seed_from_u64(seed)
197        } else {
198            scirs2_core::random::rngs::StdRng::from_rng(&mut thread_rng())
199        };
200
201        // Initialize parameters
202        let n_params = self.calculate_n_parameters(n_features);
203        let mut current_position = self.initialize_parameters(n_features, &mut rng)?;
204
205        // Initialize mass matrix (identity for simplicity)
206        let mass_matrix = Array1::ones(n_params);
207
208        // Storage for samples
209        let mut weights_samples = Array2::zeros((self.n_samples, self.n_components));
210        let mut means_samples = Array3::zeros((self.n_samples, self.n_components, n_features));
211        let mut covariances_samples = Vec::new();
212        for _ in 0..self.n_components {
213            covariances_samples.push(Array3::zeros((self.n_samples, n_features, n_features)));
214        }
215        let mut log_posterior_samples = Array1::zeros(self.n_samples);
216        let mut tree_depth_samples = Array1::zeros(self.n_samples);
217
218        // Adaptation variables
219        let mut step_size = self.step_size;
220        let mut n_accepted = 0;
221        let mut n_divergent = 0;
222
223        // NUTS sampling loop
224        for sample_idx in 0..(self.n_samples + self.n_warmup) {
225            let is_warmup = sample_idx < self.n_warmup;
226
227            // Sample momentum
228            let momentum = self.sample_momentum(&mass_matrix, &mut rng)?;
229
230            // Compute current log posterior and gradient
231            let (log_posterior, gradient) =
232                self.compute_log_posterior_and_gradient(&X, &current_position, n_features)?;
233
234            // Build NUTS tree
235            let tree_result = self.build_tree(
236                &current_position,
237                &momentum,
238                log_posterior,
239                &gradient,
240                &X,
241                n_features,
242                step_size,
243                &mass_matrix,
244                &mut rng,
245            )?;
246
247            // Accept or reject proposal
248            let accept_prob = (tree_result.proposal.log_posterior
249                - (log_posterior + 0.5 * self.kinetic_energy(&momentum, &mass_matrix)))
250            .exp()
251            .min(1.0);
252
253            let accept = rng.gen::<f64>() < accept_prob;
254
255            if accept {
256                current_position = tree_result.proposal.position.clone();
257                n_accepted += 1;
258            }
259
260            if !tree_result.valid {
261                n_divergent += 1;
262            }
263
264            // Store sample (skip warmup)
265            if !is_warmup {
266                let sample_idx_adjusted = sample_idx - self.n_warmup;
267                self.store_sample(
268                    sample_idx_adjusted,
269                    &current_position,
270                    n_features,
271                    &mut weights_samples,
272                    &mut means_samples,
273                    &mut covariances_samples,
274                    &mut log_posterior_samples,
275                    &mut tree_depth_samples,
276                    tree_result.tree_depth,
277                )?;
278            }
279
280            // Adaptation during warmup
281            if is_warmup && self.adapt_step_size {
282                step_size =
283                    self.adapt_step_size_dual_averaging(step_size, accept_prob, sample_idx + 1);
284            }
285        }
286
287        let acceptance_rate = n_accepted as f64 / (self.n_samples + self.n_warmup) as f64;
288
289        Ok(NUTSResult {
290            weights_samples,
291            means_samples,
292            covariances_samples,
293            log_posterior_samples,
294            acceptance_rate,
295            step_size_final: step_size,
296            n_divergent,
297            tree_depth_samples,
298        })
299    }
300
301    /// Calculate the number of parameters
302    fn calculate_n_parameters(&self, n_features: usize) -> usize {
303        // Simplified: weights (n_components-1) + means (n_components * n_features) +
304        // covariances (n_components * n_features for diagonal)
305        (self.n_components - 1)
306            + (self.n_components * n_features)
307            + (self.n_components * n_features)
308    }
309
310    /// Initialize parameters randomly
311    fn initialize_parameters(
312        &self,
313        n_features: usize,
314        rng: &mut scirs2_core::random::rngs::StdRng,
315    ) -> SklResult<Array1<f64>> {
316        let n_params = self.calculate_n_parameters(n_features);
317        let mut params = Array1::zeros(n_params);
318
319        // Initialize with small random values
320        for i in 0..n_params {
321            let normal = RandNormal::new(0.0, 0.1).map_err(|e| {
322                SklearsError::InvalidInput(format!("Normal distribution error: {}", e))
323            })?;
324            params[i] = rng.sample(normal);
325        }
326
327        Ok(params)
328    }
329
330    /// Sample momentum from multivariate normal distribution
331    fn sample_momentum(
332        &self,
333        mass_matrix: &Array1<f64>,
334        rng: &mut scirs2_core::random::rngs::StdRng,
335    ) -> SklResult<Array1<f64>> {
336        let n_params = mass_matrix.len();
337        let mut momentum = Array1::zeros(n_params);
338
339        for i in 0..n_params {
340            let std_dev = mass_matrix[i].sqrt();
341            momentum[i] = RandNormal::new(0.0, std_dev)
342                .map_err(|e| {
343                    SklearsError::InvalidInput(format!("Normal distribution error: {}", e))
344                })?
345                .sample(rng);
346        }
347
348        Ok(momentum)
349    }
350
351    /// Compute log posterior and its gradient
352    fn compute_log_posterior_and_gradient(
353        &self,
354        _X: &Array2<f64>,
355        position: &Array1<f64>,
356        _n_features: usize,
357    ) -> SklResult<(f64, Array1<f64>)> {
358        // Simplified implementation
359        // In practice, this would compute the actual log posterior of the mixture model
360        // and its gradient with respect to all parameters
361
362        let log_posterior = -0.5 * position.mapv(|x| x * x).sum(); // Simple quadratic penalty
363        let gradient = -position.clone(); // Gradient of quadratic penalty
364
365        Ok((log_posterior, gradient))
366    }
367
368    /// Compute kinetic energy
369    fn kinetic_energy(&self, momentum: &Array1<f64>, mass_matrix: &Array1<f64>) -> f64 {
370        0.5 * (momentum * momentum / mass_matrix).sum()
371    }
372
373    /// Build NUTS tree
374    fn build_tree(
375        &self,
376        position: &Array1<f64>,
377        momentum: &Array1<f64>,
378        log_posterior: f64,
379        gradient: &Array1<f64>,
380        X: &Array2<f64>,
381        n_features: usize,
382        step_size: f64,
383        mass_matrix: &Array1<f64>,
384        rng: &mut scirs2_core::random::rngs::StdRng,
385    ) -> SklResult<TreeResult> {
386        // Initialize tree with current state
387        let initial_node = TreeNode {
388            position: position.clone(),
389            momentum: momentum.clone(),
390            log_posterior,
391            gradient: gradient.clone(),
392        };
393
394        let mut tree_state = TreeState {
395            left_node: initial_node.clone(),
396            right_node: initial_node.clone(),
397            proposal: initial_node.clone(),
398            n_proposals: 1,
399            sum_momentum: momentum.clone(),
400            sum_momentum_squared: momentum.mapv(|x| x * x).sum(),
401            valid: true,
402        };
403
404        let mut tree_depth = 0;
405
406        // Build tree until U-turn or maximum depth
407        for depth in 0..self.max_tree_depth {
408            tree_depth = depth;
409
410            // Choose direction: forward or backward
411            let direction = if rng.gen::<f64>() < 0.5 { 1.0 } else { -1.0 };
412
413            // Build subtree in chosen direction
414            let subtree = self.build_subtree(
415                if direction > 0.0 {
416                    &tree_state.right_node
417                } else {
418                    &tree_state.left_node
419                },
420                direction,
421                depth,
422                step_size,
423                X,
424                n_features,
425                mass_matrix,
426                rng,
427            )?;
428
429            // Check for U-turn
430            if !subtree.valid || self.check_uturn(&tree_state, &subtree) {
431                break;
432            }
433
434            // Update tree state
435            if direction > 0.0 {
436                tree_state.right_node = subtree.right_node;
437            } else {
438                tree_state.left_node = subtree.left_node;
439            }
440
441            // Update proposal with probability proportional to number of proposals
442            let accept_prob =
443                subtree.n_proposals as f64 / (tree_state.n_proposals + subtree.n_proposals) as f64;
444            if rng.gen::<f64>() < accept_prob {
445                tree_state.proposal = subtree.proposal;
446            }
447
448            tree_state.n_proposals += subtree.n_proposals;
449            tree_state.sum_momentum = &tree_state.sum_momentum + &subtree.sum_momentum;
450            tree_state.sum_momentum_squared += subtree.sum_momentum_squared;
451        }
452
453        Ok(TreeResult {
454            proposal: tree_state.proposal,
455            valid: tree_state.valid,
456            tree_depth,
457        })
458    }
459
460    /// Build subtree
461    fn build_subtree(
462        &self,
463        node: &TreeNode,
464        direction: f64,
465        depth: usize,
466        step_size: f64,
467        X: &Array2<f64>,
468        n_features: usize,
469        mass_matrix: &Array1<f64>,
470        rng: &mut scirs2_core::random::rngs::StdRng,
471    ) -> SklResult<TreeState> {
472        if depth == 0 {
473            // Base case: single leapfrog step
474            let new_node =
475                self.leapfrog_step(node, direction * step_size, X, n_features, mass_matrix)?;
476
477            Ok(TreeState {
478                left_node: new_node.clone(),
479                right_node: new_node.clone(),
480                proposal: new_node.clone(),
481                n_proposals: 1,
482                sum_momentum: new_node.momentum.clone(),
483                sum_momentum_squared: new_node.momentum.mapv(|x| x * x).sum(),
484                valid: true,
485            })
486        } else {
487            // Recursive case: build left and right subtrees
488            let left_subtree = self.build_subtree(
489                node,
490                direction,
491                depth - 1,
492                step_size,
493                X,
494                n_features,
495                mass_matrix,
496                rng,
497            )?;
498
499            if !left_subtree.valid {
500                return Ok(left_subtree);
501            }
502
503            let right_subtree = self.build_subtree(
504                if direction > 0.0 {
505                    &left_subtree.right_node
506                } else {
507                    &left_subtree.left_node
508                },
509                direction,
510                depth - 1,
511                step_size,
512                X,
513                n_features,
514                mass_matrix,
515                rng,
516            )?;
517
518            if !right_subtree.valid {
519                return Ok(TreeState {
520                    left_node: left_subtree.left_node,
521                    right_node: left_subtree.right_node,
522                    proposal: left_subtree.proposal,
523                    n_proposals: left_subtree.n_proposals,
524                    sum_momentum: left_subtree.sum_momentum,
525                    sum_momentum_squared: left_subtree.sum_momentum_squared,
526                    valid: false,
527                });
528            }
529
530            // Combine subtrees
531            let total_proposals = left_subtree.n_proposals + right_subtree.n_proposals;
532            let proposal =
533                if rng.gen::<f64>() < (right_subtree.n_proposals as f64 / total_proposals as f64) {
534                    right_subtree.proposal
535                } else {
536                    left_subtree.proposal
537                };
538
539            Ok(TreeState {
540                left_node: if direction > 0.0 {
541                    left_subtree.left_node
542                } else {
543                    right_subtree.left_node
544                },
545                right_node: if direction > 0.0 {
546                    right_subtree.right_node
547                } else {
548                    left_subtree.right_node
549                },
550                proposal,
551                n_proposals: total_proposals,
552                sum_momentum: &left_subtree.sum_momentum + &right_subtree.sum_momentum,
553                sum_momentum_squared: left_subtree.sum_momentum_squared
554                    + right_subtree.sum_momentum_squared,
555                valid: true,
556            })
557        }
558    }
559
560    /// Perform leapfrog step
561    fn leapfrog_step(
562        &self,
563        node: &TreeNode,
564        step_size: f64,
565        X: &Array2<f64>,
566        n_features: usize,
567        mass_matrix: &Array1<f64>,
568    ) -> SklResult<TreeNode> {
569        // Half step for momentum
570        let momentum_half = &node.momentum + 0.5 * step_size * &node.gradient;
571
572        // Full step for position
573        let new_position = &node.position + step_size * (&momentum_half / mass_matrix);
574
575        // Compute gradient at new position
576        let (new_log_posterior, new_gradient) =
577            self.compute_log_posterior_and_gradient(X, &new_position, n_features)?;
578
579        // Half step for momentum
580        let new_momentum = &momentum_half + 0.5 * step_size * &new_gradient;
581
582        Ok(TreeNode {
583            position: new_position,
584            momentum: new_momentum,
585            log_posterior: new_log_posterior,
586            gradient: new_gradient,
587        })
588    }
589
590    /// Check for U-turn condition
591    fn check_uturn(&self, tree_state: &TreeState, _subtree: &TreeState) -> bool {
592        // Simplified U-turn check using momentum dot product
593        let left_momentum = &tree_state.left_node.momentum;
594        let right_momentum = &tree_state.right_node.momentum;
595        let momentum_diff = right_momentum - left_momentum;
596
597        // U-turn if momentum is pointing back towards start
598        left_momentum.dot(&momentum_diff) < 0.0 || right_momentum.dot(&momentum_diff) < 0.0
599    }
600
601    /// Adapt step size using dual averaging
602    fn adapt_step_size_dual_averaging(
603        &self,
604        current_step_size: f64,
605        accept_prob: f64,
606        iteration: usize,
607    ) -> f64 {
608        let delta = self.target_accept_rate - accept_prob;
609        let gamma = 0.05;
610        let t0 = 10.0;
611        let kappa = 0.75;
612
613        let eta = 1.0 / (iteration as f64 + t0);
614        let log_step_size = current_step_size.ln() + eta * delta;
615
616        (log_step_size - gamma * (iteration as f64).powf(-kappa) * delta).exp()
617    }
618
619    /// Store sample in result arrays
620    fn store_sample(
621        &self,
622        sample_idx: usize,
623        position: &Array1<f64>,
624        n_features: usize,
625        weights_samples: &mut Array2<f64>,
626        means_samples: &mut Array3<f64>,
627        covariances_samples: &mut Vec<Array3<f64>>,
628        log_posterior_samples: &mut Array1<f64>,
629        tree_depth_samples: &mut Array1<usize>,
630        tree_depth: usize,
631    ) -> SklResult<()> {
632        // Decode parameters from position vector
633        // This is simplified - in practice would properly decode mixture parameters
634        let (weights, means, covariances) = self.decode_parameters(position, n_features)?;
635
636        // Store weights
637        for k in 0..self.n_components {
638            weights_samples[[sample_idx, k]] = weights[k];
639        }
640
641        // Store means
642        for k in 0..self.n_components {
643            for j in 0..n_features {
644                means_samples[[sample_idx, k, j]] = means[[k, j]];
645            }
646        }
647
648        // Store covariances (simplified as diagonal)
649        for k in 0..self.n_components {
650            for i in 0..n_features {
651                for j in 0..n_features {
652                    covariances_samples[k][[sample_idx, i, j]] =
653                        if i == j { covariances[k] } else { 0.0 };
654                }
655            }
656        }
657
658        // Store log posterior (simplified)
659        log_posterior_samples[sample_idx] = -0.5 * position.mapv(|x| x * x).sum();
660
661        // Store tree depth
662        tree_depth_samples[sample_idx] = tree_depth;
663
664        Ok(())
665    }
666
667    /// Decode parameters from position vector
668    fn decode_parameters(
669        &self,
670        _position: &Array1<f64>,
671        n_features: usize,
672    ) -> SklResult<(Array1<f64>, Array2<f64>, Array1<f64>)> {
673        // Simplified parameter decoding
674        let weights = Array1::ones(self.n_components) / self.n_components as f64;
675        let means = Array2::zeros((self.n_components, n_features));
676        let covariances = Array1::ones(self.n_components);
677
678        // In practice, this would properly decode the constrained parameters
679        // from the unconstrained position vector
680
681        Ok((weights, means, covariances))
682    }
683}
684
/// Result of building a NUTS tree
#[derive(Debug)]
struct TreeResult {
    /// Phase-space point proposed as the next MCMC state.
    proposal: TreeNode,
    /// False when the trajectory terminated abnormally.
    valid: bool,
    /// Number of doublings performed before termination.
    tree_depth: usize,
}
692
693impl Default for NUTSSampler {
694    fn default() -> Self {
695        Self::new()
696    }
697}
698
699impl NUTSResult {
    /// Mixture-weight draws, shape `(n_samples, n_components)`.
    pub fn weights_samples(&self) -> &Array2<f64> {
        &self.weights_samples
    }

    /// Component-mean draws, shape `(n_samples, n_components, n_features)`.
    pub fn means_samples(&self) -> &Array3<f64> {
        &self.means_samples
    }

    /// Per-component covariance draws; entry `k` has shape
    /// `(n_samples, n_features, n_features)`.
    pub fn covariances_samples(&self) -> &[Array3<f64>] {
        &self.covariances_samples
    }

    /// Log-posterior value recorded for each retained draw.
    pub fn log_posterior_samples(&self) -> &Array1<f64> {
        &self.log_posterior_samples
    }

    /// Fraction of accepted proposals over all iterations (warmup included).
    pub fn acceptance_rate(&self) -> f64 {
        self.acceptance_rate
    }

    /// Step size in effect after the final adaptation update.
    pub fn step_size_final(&self) -> f64 {
        self.step_size_final
    }

    /// Number of divergent (invalid) trajectories encountered.
    pub fn n_divergent(&self) -> usize {
        self.n_divergent
    }

    /// Tree depth reached on each retained draw.
    pub fn tree_depth_samples(&self) -> &Array1<usize> {
        &self.tree_depth_samples
    }
739
740    /// Compute posterior means
741    pub fn posterior_means(&self) -> SklResult<(Array1<f64>, Array2<f64>, Vec<Array2<f64>>)> {
742        let n_samples = self.weights_samples.shape()[0];
743        let n_components = self.weights_samples.shape()[1];
744        let n_features = self.means_samples.shape()[2];
745
746        // Compute mean weights
747        let mut mean_weights = Array1::zeros(n_components);
748        for k in 0..n_components {
749            mean_weights[k] = self.weights_samples.column(k).mean().unwrap_or(0.0);
750        }
751
752        // Compute mean means
753        let mut mean_means = Array2::zeros((n_components, n_features));
754        for k in 0..n_components {
755            for j in 0..n_features {
756                let values: Vec<f64> = (0..n_samples)
757                    .map(|i| self.means_samples[[i, k, j]])
758                    .collect();
759                mean_means[[k, j]] = values.iter().sum::<f64>() / n_samples as f64;
760            }
761        }
762
763        // Compute mean covariances
764        let mut mean_covariances = Vec::new();
765        for k in 0..n_components {
766            let mut mean_cov = Array2::zeros((n_features, n_features));
767            for i in 0..n_features {
768                for j in 0..n_features {
769                    let values: Vec<f64> = (0..n_samples)
770                        .map(|s| self.covariances_samples[k][[s, i, j]])
771                        .collect();
772                    mean_cov[[i, j]] = values.iter().sum::<f64>() / n_samples as f64;
773                }
774            }
775            mean_covariances.push(mean_cov);
776        }
777
778        Ok((mean_weights, mean_means, mean_covariances))
779    }
780
781    /// Compute credible intervals
782    pub fn credible_intervals(&self, alpha: f64) -> SklResult<(Array2<f64>, Array3<f64>)> {
783        let n_components = self.weights_samples.shape()[1];
784        let n_features = self.means_samples.shape()[2];
785
786        // Weight credible intervals
787        let mut weight_intervals = Array2::zeros((n_components, 2));
788        for k in 0..n_components {
789            let mut values: Vec<f64> = self.weights_samples.column(k).to_vec();
790            values.sort_by(|a, b| a.partial_cmp(b).unwrap());
791            let lower_idx = ((alpha / 2.0) * values.len() as f64) as usize;
792            let upper_idx = ((1.0 - alpha / 2.0) * values.len() as f64) as usize;
793            weight_intervals[[k, 0]] = values[lower_idx];
794            weight_intervals[[k, 1]] = values[upper_idx.min(values.len() - 1)];
795        }
796
797        // Mean credible intervals
798        let mut mean_intervals = Array3::zeros((n_components, n_features, 2));
799        for k in 0..n_components {
800            for j in 0..n_features {
801                let mut values: Vec<f64> = (0..self.means_samples.shape()[0])
802                    .map(|i| self.means_samples[[i, k, j]])
803                    .collect();
804                values.sort_by(|a, b| a.partial_cmp(b).unwrap());
805                let lower_idx = ((alpha / 2.0) * values.len() as f64) as usize;
806                let upper_idx = ((1.0 - alpha / 2.0) * values.len() as f64) as usize;
807                mean_intervals[[k, j, 0]] = values[lower_idx];
808                mean_intervals[[k, j, 1]] = values[upper_idx.min(values.len() - 1)];
809            }
810        }
811
812        Ok((weight_intervals, mean_intervals))
813    }
814}
815
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    // Smoke test: a tiny run must produce correctly-shaped sample arrays and
    // an acceptance rate inside [0, 1].
    #[test]
    #[allow(non_snake_case)]
    fn test_nuts_sampler_basic() {
        let X = array![[0.0, 0.0], [1.0, 1.0], [10.0, 10.0], [11.0, 11.0]];

        let nuts = NUTSSampler::new()
            .n_components(2)
            .n_samples(10)
            .n_warmup(5)
            .max_tree_depth(3)
            .random_state(42);

        let result = nuts.sample(&X.view()).unwrap();

        assert_eq!(result.weights_samples.shape(), &[10, 2]);
        assert_eq!(result.means_samples.shape(), &[10, 2, 2]);
        assert_eq!(result.covariances_samples.len(), 2);
        assert!(result.acceptance_rate >= 0.0 && result.acceptance_rate <= 1.0);
    }

    // Every builder setter must store its value on the sampler.
    #[test]
    fn test_nuts_sampler_builder() {
        let nuts = NUTSSampler::builder()
            .n_components(3)
            .n_samples(100)
            .n_warmup(50)
            .step_size(0.05)
            .target_accept_rate(0.85)
            .max_tree_depth(12)
            .adapt_step_size(true)
            .adapt_mass_matrix(false)
            .random_state(123)
            .build();

        assert_eq!(nuts.n_components, 3);
        assert_eq!(nuts.n_samples, 100);
        assert_eq!(nuts.n_warmup, 50);
        assert_relative_eq!(nuts.step_size, 0.05);
        assert_relative_eq!(nuts.target_accept_rate, 0.85);
        assert_eq!(nuts.max_tree_depth, 12);
        assert!(nuts.adapt_step_size);
        assert!(!nuts.adapt_mass_matrix);
    }

    // Posterior summaries (means and credible intervals) must have the
    // expected shapes for a single-component, single-feature run.
    #[test]
    #[allow(non_snake_case)]
    fn test_nuts_result_analysis() {
        let X = array![[0.0], [1.0], [2.0]];

        let nuts = NUTSSampler::new()
            .n_components(1)
            .n_samples(5)
            .n_warmup(2)
            .random_state(42);

        let result = nuts.sample(&X.view()).unwrap();

        // Test posterior means computation
        let (mean_weights, mean_means, mean_covariances) = result.posterior_means().unwrap();
        assert_eq!(mean_weights.len(), 1);
        assert_eq!(mean_means.shape(), &[1, 1]);
        assert_eq!(mean_covariances.len(), 1);

        // Test credible intervals
        let (weight_intervals, mean_intervals) = result.credible_intervals(0.05).unwrap();
        assert_eq!(weight_intervals.shape(), &[1, 2]);
        assert_eq!(mean_intervals.shape(), &[1, 1, 2]);
    }

    // With a unit mass matrix, K(p) = 0.5 * ||p||^2.
    #[test]
    fn test_kinetic_energy() {
        let nuts = NUTSSampler::new();
        let momentum = array![1.0, 2.0, 3.0];
        let mass_matrix = array![1.0, 1.0, 1.0];

        let ke = nuts.kinetic_energy(&momentum, &mass_matrix);
        assert_relative_eq!(ke, 7.0); // 0.5 * (1 + 4 + 9)
    }

    // Adaptation must always return a positive step size, whether the
    // observed acceptance is below or above the target.
    #[test]
    fn test_step_size_adaptation() {
        let nuts = NUTSSampler::new().target_accept_rate(0.8);

        let step_size = nuts.adapt_step_size_dual_averaging(0.1, 0.6, 10);
        assert!(step_size > 0.0);

        let step_size2 = nuts.adapt_step_size_dual_averaging(0.1, 0.9, 10);
        assert!(step_size2 > 0.0);
    }
}