Skip to main content

scirs2_stats/mcmc/
gibbs.rs

//! Gibbs sampling for MCMC
//!
//! Gibbs sampling is an MCMC method for sampling from multivariate distributions
//! when direct sampling is difficult but conditional distributions are available.
5
6use crate::error::{StatsError, StatsResult as Result};
7use scirs2_core::ndarray::{Array1, Array2};
8use scirs2_core::validation::*;
9use scirs2_core::{Rng, RngExt};
10use std::fmt::Debug;
11
12/// Conditional distribution trait for Gibbs sampling
13pub trait ConditionalDistribution: Send + Sync {
14    /// Sample from the conditional distribution P(X_i | X_{-i})
15    ///
16    /// # Arguments
17    /// * `current_state` - Current values of all variables
18    /// * `variable_index` - Index of the variable to sample
19    /// * `rng` - Random number generator
20    ///
21    /// # Returns
22    /// New value for the variable at `variable_index`
23    fn sample_conditional<R: Rng + ?Sized>(
24        &self,
25        current_state: &Array1<f64>,
26        variable_index: usize,
27        rng: &mut R,
28    ) -> Result<f64>;
29
30    /// Get the dimensionality of the distribution
31    fn dim(&self) -> usize;
32
33    /// Optionally compute log density for monitoring
34    fn log_density(&self, x: &Array1<f64>) -> Option<f64> {
35        None
36    }
37}
38
39/// Gibbs sampler
40pub struct GibbsSampler<C: ConditionalDistribution> {
41    /// Conditional distributions
42    pub conditionals: C,
43    /// Current state
44    pub current: Array1<f64>,
45    /// Number of samples generated
46    pub n_samples_: usize,
47    /// Variable update order (None for sequential, Some for custom order)
48    pub update_order: Option<Vec<usize>>,
49}
50
51impl<C: ConditionalDistribution> GibbsSampler<C> {
52    /// Create a new Gibbs sampler
53    pub fn new(conditionals: C, initial: Array1<f64>) -> Result<Self> {
54        checkarray_finite(&initial, "initial")?;
55        if initial.len() != conditionals.dim() {
56            return Err(StatsError::DimensionMismatch(format!(
57                "initial dimension ({}) must match conditionals dimension ({})",
58                initial.len(),
59                conditionals.dim()
60            )));
61        }
62
63        Ok(Self {
64            conditionals,
65            current: initial,
66            n_samples_: 0,
67            update_order: None,
68        })
69    }
70
71    /// Set custom variable update order
72    pub fn with_update_order(mut self, order: Vec<usize>) -> Result<Self> {
73        if order.len() != self.conditionals.dim() {
74            return Err(StatsError::InvalidArgument(
75                "Update order length must match dimension".to_string(),
76            ));
77        }
78
79        // Check that all indices are valid and unique
80        let mut sorted_order = order.clone();
81        sorted_order.sort_unstable();
82        for (i, &idx) in sorted_order.iter().enumerate() {
83            if idx != i {
84                return Err(StatsError::InvalidArgument(
85                    "Update order must contain each index exactly once".to_string(),
86                ));
87            }
88        }
89
90        self.update_order = Some(order);
91        Ok(self)
92    }
93
94    /// Perform one full sweep of Gibbs sampling
95    pub fn step<R: Rng + ?Sized>(&mut self, rng: &mut R) -> Result<Array1<f64>> {
96        let dim = self.current.len();
97
98        // Determine update order
99        let order = match &self.update_order {
100            Some(order) => order.clone(),
101            None => (0..dim).collect(),
102        };
103
104        // Update each variable in order
105        for &var_idx in &order {
106            let new_value = self
107                .conditionals
108                .sample_conditional(&self.current, var_idx, rng)?;
109            self.current[var_idx] = new_value;
110        }
111
112        self.n_samples_ += 1;
113        Ok(self.current.clone())
114    }
115
116    /// Sample multiple states
117    pub fn sample<R: Rng + ?Sized>(
118        &mut self,
119        n_samples_: usize,
120        rng: &mut R,
121    ) -> Result<Array2<f64>> {
122        let dim = self.current.len();
123        let mut samples = Array2::zeros((n_samples_, dim));
124
125        for i in 0..n_samples_ {
126            let sample = self.step(rng)?;
127            samples.row_mut(i).assign(&sample);
128        }
129
130        Ok(samples)
131    }
132
133    /// Sample with burn-in period
134    pub fn sample_with_burnin<R: Rng + ?Sized>(
135        &mut self,
136        n_samples_: usize,
137        burnin: usize,
138        rng: &mut R,
139    ) -> Result<Array2<f64>> {
140        check_positive(burnin, "burnin")?;
141
142        // Burn-in period
143        for _ in 0..burnin {
144            self.step(rng)?;
145        }
146
147        // Collect _samples
148        self.sample(n_samples_, rng)
149    }
150
151    /// Sample with thinning to reduce autocorrelation
152    pub fn sample_thinned<R: Rng + ?Sized>(
153        &mut self,
154        n_samples_: usize,
155        thin: usize,
156        rng: &mut R,
157    ) -> Result<Array2<f64>> {
158        check_positive(thin, "thin")?;
159
160        let dim = self.current.len();
161        let mut samples = Array2::zeros((n_samples_, dim));
162
163        for i in 0..n_samples_ {
164            // Take thin steps but only keep the last one
165            for _ in 0..thin {
166                self.step(rng)?;
167            }
168            samples.row_mut(i).assign(&self.current);
169        }
170
171        Ok(samples)
172    }
173}
174
175/// Multivariate normal Gibbs sampler
176///
177/// For sampling from a multivariate normal distribution where each variable
178/// given all others follows a normal distribution.
179#[derive(Debug, Clone)]
180pub struct MultivariateNormalGibbs {
181    /// Mean vector
182    pub mean: Array1<f64>,
183    /// Precision matrix (inverse covariance)
184    pub precision: Array2<f64>,
185}
186
187impl MultivariateNormalGibbs {
188    /// Create a new multivariate normal Gibbs sampler
189    pub fn new(mean: Array1<f64>, covariance: Array2<f64>) -> Result<Self> {
190        checkarray_finite(&mean, "mean")?;
191        checkarray_finite(&covariance, "covariance")?;
192
193        if covariance.nrows() != mean.len() || covariance.ncols() != mean.len() {
194            return Err(StatsError::DimensionMismatch(format!(
195                "covariance shape ({}, {}) must be ({}, {})",
196                covariance.nrows(),
197                covariance.ncols(),
198                mean.len(),
199                mean.len()
200            )));
201        }
202
203        // Compute precision matrix
204        let precision = scirs2_linalg::inv(&covariance.view(), None).map_err(|e| {
205            StatsError::ComputationError(format!("Failed to invert covariance: {}", e))
206        })?;
207
208        Ok(Self { mean, precision })
209    }
210
211    /// Create from precision matrix directly
212    pub fn from_precision(mean: Array1<f64>, precision: Array2<f64>) -> Result<Self> {
213        checkarray_finite(&mean, "mean")?;
214        checkarray_finite(&precision, "precision")?;
215
216        if precision.nrows() != mean.len() || precision.ncols() != mean.len() {
217            return Err(StatsError::DimensionMismatch(format!(
218                "precision shape ({}, {}) must be ({}, {})",
219                precision.nrows(),
220                precision.ncols(),
221                mean.len(),
222                mean.len()
223            )));
224        }
225
226        Ok(Self { mean, precision })
227    }
228}
229
230impl ConditionalDistribution for MultivariateNormalGibbs {
231    fn sample_conditional<R: Rng + ?Sized>(
232        &self,
233        current_state: &Array1<f64>,
234        variable_index: usize,
235        rng: &mut R,
236    ) -> Result<f64> {
237        let dim = self.mean.len();
238        if variable_index >= dim {
239            return Err(StatsError::InvalidArgument(format!(
240                "variable_index ({}) must be less than dimension ({})",
241                variable_index, dim
242            )));
243        }
244
245        // For multivariate normal, conditional distribution is:
246        // X_i | X_{-i} ~ Normal(mu_i + Sigma_{i,-i} * Sigma_{-i,-i}^{-1} * (X_{-i} - mu_{-i}), Sigma_{ii|{-i}})
247        // Where Sigma_{ii|{-i}} = 1 / Precision_{ii}
248
249        let precision_ii = self.precision[[variable_index, variable_index]];
250        if precision_ii.abs() < f64::EPSILON {
251            return Err(StatsError::ComputationError(
252                "Precision matrix must have positive diagonal elements".to_string(),
253            ));
254        }
255
256        // Conditional variance
257        let conditional_variance = 1.0 / precision_ii;
258        let conditional_std = conditional_variance.sqrt();
259
260        // Conditional mean
261        let mut sum = 0.0;
262        for j in 0..dim {
263            if j != variable_index {
264                sum += self.precision[[variable_index, j]] * (current_state[j] - self.mean[j]);
265            }
266        }
267        let conditional_mean = self.mean[variable_index] - sum / precision_ii;
268
269        // Sample from normal distribution
270        use scirs2_core::random::{Distribution, Normal};
271        let normal = Normal::new(conditional_mean, conditional_std).map_err(|e| {
272            StatsError::ComputationError(format!("Failed to create normal distribution: {}", e))
273        })?;
274
275        Ok(normal.sample(rng))
276    }
277
278    fn dim(&self) -> usize {
279        self.mean.len()
280    }
281
282    fn log_density(&self, x: &Array1<f64>) -> Option<f64> {
283        let diff = x - &self.mean;
284        let quad_form = diff.dot(&self.precision.dot(&diff));
285
286        // Compute log determinant of precision
287        let det = scirs2_linalg::det(&self.precision.view(), None).ok()?;
288        if det <= 0.0 {
289            return None;
290        }
291
292        let d = self.mean.len() as f64;
293        let log_norm_const = 0.5 * (det.ln() - d * (2.0 * std::f64::consts::PI).ln());
294
295        Some(log_norm_const - 0.5 * quad_form)
296    }
297}
298
299/// Gaussian mixture model Gibbs sampler
300///
301/// Samples component assignments and parameters for a Gaussian mixture model
302#[derive(Debug, Clone)]
303pub struct GaussianMixtureGibbs {
304    /// Current component means
305    pub means: Array2<f64>,
306    /// Current component precisions
307    pub precisions: Vec<Array2<f64>>,
308    /// Current mixing weights
309    pub weights: Array1<f64>,
310    /// Data points
311    pub data: Array2<f64>,
312    /// Current component assignments
313    pub assignments: Array1<usize>,
314    /// Number of components
315    pub n_components: usize,
316    /// Hyperparameters for priors
317    pub prior_mean: Array1<f64>,
318    pub prior_precision: Array2<f64>,
319    pub prior_alpha: Array1<f64>, // Dirichlet prior for weights
320}
321
322impl GaussianMixtureGibbs {
323    /// Create a new Gaussian mixture Gibbs sampler
324    pub fn new(
325        data: Array2<f64>,
326        n_components: usize,
327        prior_mean: Array1<f64>,
328        prior_precision: Array2<f64>,
329        prior_alpha: Array1<f64>,
330    ) -> Result<Self> {
331        checkarray_finite(&data, "data")?;
332        check_positive(n_components, "n_components")?;
333        checkarray_finite(&prior_mean, "prior_mean")?;
334        checkarray_finite(&prior_precision, "prior_precision")?;
335        checkarray_finite(&prior_alpha, "prior_alpha")?;
336
337        let (n_samples_, dim) = data.dim();
338
339        if prior_alpha.len() != n_components {
340            return Err(StatsError::DimensionMismatch(format!(
341                "prior_alpha length ({}) must equal n_components ({})",
342                prior_alpha.len(),
343                n_components
344            )));
345        }
346
347        // Initialize parameters
348        let means = Array2::zeros((n_components, dim));
349        let precisions = vec![Array2::eye(dim); n_components];
350        let weights = Array1::from_elem(n_components, 1.0 / n_components as f64);
351        let assignments = Array1::zeros(n_samples_);
352
353        Ok(Self {
354            means,
355            precisions,
356            weights,
357            data,
358            assignments,
359            n_components,
360            prior_mean,
361            prior_precision,
362            prior_alpha,
363        })
364    }
365
366    /// Perform one step of Gibbs sampling for GMM
367    pub fn step<R: Rng + ?Sized>(&mut self, rng: &mut R) -> Result<()> {
368        // 1. Sample component assignments
369        self.sample_assignments(rng)?;
370
371        // 2. Sample component parameters given assignments
372        self.sample_parameters(rng)?;
373
374        // 3. Sample mixing weights
375        self.sample_weights(rng)?;
376
377        Ok(())
378    }
379
380    /// Sample component assignments for each data point
381    fn sample_assignments<R: Rng + ?Sized>(&mut self, rng: &mut R) -> Result<()> {
382        for i in 0..self.data.nrows() {
383            let data_point = self.data.row(i);
384            let mut log_probs = Array1::zeros(self.n_components);
385
386            // Compute log probabilities for each component
387            for k in 0..self.n_components {
388                let mean_k = self.means.row(k);
389                let precision_k = &self.precisions[k];
390
391                let diff = &data_point.to_owned() - &mean_k.to_owned();
392                let quad_form = diff.dot(&precision_k.dot(&diff));
393
394                // Log determinant of precision
395                let det = scirs2_linalg::det(&precision_k.view(), None).map_err(|e| {
396                    StatsError::ComputationError(format!("Failed to compute determinant: {}", e))
397                })?;
398
399                if det <= 0.0 {
400                    return Err(StatsError::ComputationError(
401                        "Precision matrix must be positive definite".to_string(),
402                    ));
403                }
404
405                let log_likelihood = 0.5 * det.ln() - 0.5 * quad_form;
406                log_probs[k] = self.weights[k].ln() + log_likelihood;
407            }
408
409            // Convert to probabilities and sample
410            let max_log_prob = log_probs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
411            let mut probs = log_probs.mapv(|x| (x - max_log_prob).exp());
412            let prob_sum = probs.sum();
413            probs /= prob_sum;
414
415            // Sample from categorical distribution
416            let u: f64 = rng.random();
417            let mut cumsum = 0.0;
418            let mut selected = 0;
419
420            for (k, &p) in probs.iter().enumerate() {
421                cumsum += p;
422                if u <= cumsum {
423                    selected = k;
424                    break;
425                }
426            }
427
428            self.assignments[i] = selected;
429        }
430
431        Ok(())
432    }
433
434    /// Sample component parameters given assignments
435    fn sample_parameters<R: Rng + ?Sized>(&mut self, rng: &mut R) -> Result<()> {
436        for k in 0..self.n_components {
437            // Find data points assigned to component k
438            let assigned_indices: Vec<usize> = self
439                .assignments
440                .iter()
441                .enumerate()
442                .filter_map(|(i, &assignment)| if assignment == k { Some(i) } else { None })
443                .collect();
444
445            if assigned_indices.is_empty() {
446                // No data assigned to this component, sample from prior
447                self.sample_from_prior(k, rng)?;
448            } else {
449                // Sample posterior given assigned data
450                self.sample_posterior(k, &assigned_indices, rng)?;
451            }
452        }
453
454        Ok(())
455    }
456
457    /// Sample parameters from prior when no data is assigned
458    fn sample_from_prior<R: Rng + ?Sized>(&mut self, component: usize, rng: &mut R) -> Result<()> {
459        // Sample mean from prior
460        use scirs2_core::random::{Distribution, Normal};
461
462        let dim = self.prior_mean.len();
463        let mut new_mean = Array1::zeros(dim);
464
465        // For simplicity, assume diagonal prior precision
466        for i in 0..dim {
467            let variance = 1.0 / self.prior_precision[[i, i]];
468            let std = variance.sqrt();
469            let normal = Normal::new(self.prior_mean[i], std).map_err(|e| {
470                StatsError::ComputationError(format!("Failed to create normal: {}", e))
471            })?;
472            new_mean[i] = normal.sample(rng);
473        }
474
475        self.means.row_mut(component).assign(&new_mean);
476
477        // For precision, use prior (simplified - in practice would sample from Wishart)
478        self.precisions[component] = self.prior_precision.clone();
479
480        Ok(())
481    }
482
483    /// Sample posterior parameters given assigned data
484    fn sample_posterior<R: Rng + ?Sized>(
485        &mut self,
486        component: usize,
487        assigned_indices: &[usize],
488        rng: &mut R,
489    ) -> Result<()> {
490        let n_assigned = assigned_indices.len();
491        let dim = self.prior_mean.len();
492
493        // Compute sample mean
494        let mut sample_mean = Array1::zeros(dim);
495        for &i in assigned_indices {
496            sample_mean = sample_mean + self.data.row(i);
497        }
498        sample_mean /= n_assigned as f64;
499
500        // Posterior parameters for mean (assuming identity precision for simplicity)
501        let posterior_precision = &self.prior_precision + Array2::eye(dim) * n_assigned as f64;
502        let posterior_mean = {
503            let prior_contrib = self.prior_precision.dot(&self.prior_mean);
504            let data_contrib = Array1::from_elem(dim, n_assigned as f64) * &sample_mean;
505            let precision_inv =
506                scirs2_linalg::inv(&posterior_precision.view(), None).map_err(|e| {
507                    StatsError::ComputationError(format!("Failed to invert precision: {}", e))
508                })?;
509            precision_inv.dot(&(prior_contrib + data_contrib))
510        };
511
512        // Sample new mean
513        use scirs2_core::random::{Distribution, Normal};
514        let mut new_mean = Array1::zeros(dim);
515
516        for i in 0..dim {
517            let variance = 1.0 / posterior_precision[[i, i]];
518            let std = variance.sqrt();
519            let normal = Normal::new(posterior_mean[i], std).map_err(|e| {
520                StatsError::ComputationError(format!("Failed to create normal: {}", e))
521            })?;
522            new_mean[i] = normal.sample(rng);
523        }
524
525        self.means.row_mut(component).assign(&new_mean);
526
527        // Update precision (simplified)
528        self.precisions[component] = posterior_precision;
529
530        Ok(())
531    }
532
533    /// Sample mixing weights from Dirichlet posterior
534    fn sample_weights<R: Rng + ?Sized>(&mut self, rng: &mut R) -> Result<()> {
535        // Count assignments to each component
536        let mut counts = Array1::<f64>::zeros(self.n_components);
537        for &assignment in self.assignments.iter() {
538            counts[assignment] += 1.0;
539        }
540
541        // Posterior parameters for Dirichlet
542        let posterior_alpha = &self.prior_alpha + &counts;
543
544        // Sample from Dirichlet (using Gamma sampling)
545        use scirs2_core::random::{Distribution, Gamma};
546        let mut gamma_samples = Array1::zeros(self.n_components);
547
548        for k in 0..self.n_components {
549            let gamma = Gamma::new(posterior_alpha[k], 1.0).map_err(|e| {
550                StatsError::ComputationError(format!("Failed to create Gamma: {}", e))
551            })?;
552            gamma_samples[k] = gamma.sample(rng);
553        }
554
555        // Normalize to get Dirichlet sample
556        let sum = gamma_samples.sum();
557        self.weights = gamma_samples / sum;
558
559        Ok(())
560    }
561}
562
563/// Blocked Gibbs sampler for improved efficiency
564///
565/// Updates blocks of variables together rather than one at a time
566pub struct BlockedGibbsSampler<C: ConditionalDistribution> {
567    /// Base Gibbs sampler
568    pub sampler: GibbsSampler<C>,
569    /// Variable blocks (each inner vec contains indices of variables to update together)
570    pub blocks: Vec<Vec<usize>>,
571}
572
573impl<C: ConditionalDistribution> BlockedGibbsSampler<C> {
574    /// Create a new blocked Gibbs sampler
575    pub fn new(conditionals: C, initial: Array1<f64>, blocks: Vec<Vec<usize>>) -> Result<Self> {
576        let sampler = GibbsSampler::new(conditionals, initial)?;
577
578        // Validate blocks
579        let dim = sampler.conditionals.dim();
580        let mut all_indices = Vec::new();
581        for block in &blocks {
582            for &idx in block {
583                if idx >= dim {
584                    return Err(StatsError::InvalidArgument(format!(
585                        "Block index {} exceeds dimension {}",
586                        idx, dim
587                    )));
588                }
589                all_indices.push(idx);
590            }
591        }
592
593        all_indices.sort_unstable();
594        all_indices.dedup();
595        if all_indices.len() != dim {
596            return Err(StatsError::InvalidArgument(
597                "Blocks must cover all variables exactly once".to_string(),
598            ));
599        }
600
601        Ok(Self { sampler, blocks })
602    }
603
604    /// Perform one step of blocked Gibbs sampling
605    pub fn step<R: Rng + ?Sized>(&mut self, rng: &mut R) -> Result<Array1<f64>> {
606        // Update each block
607        for block in &self.blocks {
608            for &var_idx in block {
609                let new_value = self.sampler.conditionals.sample_conditional(
610                    &self.sampler.current,
611                    var_idx,
612                    rng,
613                )?;
614                self.sampler.current[var_idx] = new_value;
615            }
616        }
617
618        self.sampler.n_samples_ += 1;
619        Ok(self.sampler.current.clone())
620    }
621
622    /// Sample multiple states
623    pub fn sample<R: Rng + ?Sized>(
624        &mut self,
625        n_samples_: usize,
626        rng: &mut R,
627    ) -> Result<Array2<f64>> {
628        let dim = self.sampler.current.len();
629        let mut samples = Array2::zeros((n_samples_, dim));
630
631        for i in 0..n_samples_ {
632            let sample = self.step(rng)?;
633            samples.row_mut(i).assign(&sample);
634        }
635
636        Ok(samples)
637    }
638}