1use crate::error::{StatsError, StatsResult as Result};
7use scirs2_core::ndarray::{Array1, Array2};
8use scirs2_core::validation::*;
9use scirs2_core::{Rng, RngExt};
10use std::fmt::Debug;
11
12pub trait ConditionalDistribution: Send + Sync {
14 fn sample_conditional<R: Rng + ?Sized>(
24 &self,
25 current_state: &Array1<f64>,
26 variable_index: usize,
27 rng: &mut R,
28 ) -> Result<f64>;
29
30 fn dim(&self) -> usize;
32
33 fn log_density(&self, x: &Array1<f64>) -> Option<f64> {
35 None
36 }
37}
38
39pub struct GibbsSampler<C: ConditionalDistribution> {
41 pub conditionals: C,
43 pub current: Array1<f64>,
45 pub n_samples_: usize,
47 pub update_order: Option<Vec<usize>>,
49}
50
51impl<C: ConditionalDistribution> GibbsSampler<C> {
52 pub fn new(conditionals: C, initial: Array1<f64>) -> Result<Self> {
54 checkarray_finite(&initial, "initial")?;
55 if initial.len() != conditionals.dim() {
56 return Err(StatsError::DimensionMismatch(format!(
57 "initial dimension ({}) must match conditionals dimension ({})",
58 initial.len(),
59 conditionals.dim()
60 )));
61 }
62
63 Ok(Self {
64 conditionals,
65 current: initial,
66 n_samples_: 0,
67 update_order: None,
68 })
69 }
70
71 pub fn with_update_order(mut self, order: Vec<usize>) -> Result<Self> {
73 if order.len() != self.conditionals.dim() {
74 return Err(StatsError::InvalidArgument(
75 "Update order length must match dimension".to_string(),
76 ));
77 }
78
79 let mut sorted_order = order.clone();
81 sorted_order.sort_unstable();
82 for (i, &idx) in sorted_order.iter().enumerate() {
83 if idx != i {
84 return Err(StatsError::InvalidArgument(
85 "Update order must contain each index exactly once".to_string(),
86 ));
87 }
88 }
89
90 self.update_order = Some(order);
91 Ok(self)
92 }
93
94 pub fn step<R: Rng + ?Sized>(&mut self, rng: &mut R) -> Result<Array1<f64>> {
96 let dim = self.current.len();
97
98 let order = match &self.update_order {
100 Some(order) => order.clone(),
101 None => (0..dim).collect(),
102 };
103
104 for &var_idx in &order {
106 let new_value = self
107 .conditionals
108 .sample_conditional(&self.current, var_idx, rng)?;
109 self.current[var_idx] = new_value;
110 }
111
112 self.n_samples_ += 1;
113 Ok(self.current.clone())
114 }
115
116 pub fn sample<R: Rng + ?Sized>(
118 &mut self,
119 n_samples_: usize,
120 rng: &mut R,
121 ) -> Result<Array2<f64>> {
122 let dim = self.current.len();
123 let mut samples = Array2::zeros((n_samples_, dim));
124
125 for i in 0..n_samples_ {
126 let sample = self.step(rng)?;
127 samples.row_mut(i).assign(&sample);
128 }
129
130 Ok(samples)
131 }
132
133 pub fn sample_with_burnin<R: Rng + ?Sized>(
135 &mut self,
136 n_samples_: usize,
137 burnin: usize,
138 rng: &mut R,
139 ) -> Result<Array2<f64>> {
140 check_positive(burnin, "burnin")?;
141
142 for _ in 0..burnin {
144 self.step(rng)?;
145 }
146
147 self.sample(n_samples_, rng)
149 }
150
151 pub fn sample_thinned<R: Rng + ?Sized>(
153 &mut self,
154 n_samples_: usize,
155 thin: usize,
156 rng: &mut R,
157 ) -> Result<Array2<f64>> {
158 check_positive(thin, "thin")?;
159
160 let dim = self.current.len();
161 let mut samples = Array2::zeros((n_samples_, dim));
162
163 for i in 0..n_samples_ {
164 for _ in 0..thin {
166 self.step(rng)?;
167 }
168 samples.row_mut(i).assign(&self.current);
169 }
170
171 Ok(samples)
172 }
173}
174
175#[derive(Debug, Clone)]
180pub struct MultivariateNormalGibbs {
181 pub mean: Array1<f64>,
183 pub precision: Array2<f64>,
185}
186
187impl MultivariateNormalGibbs {
188 pub fn new(mean: Array1<f64>, covariance: Array2<f64>) -> Result<Self> {
190 checkarray_finite(&mean, "mean")?;
191 checkarray_finite(&covariance, "covariance")?;
192
193 if covariance.nrows() != mean.len() || covariance.ncols() != mean.len() {
194 return Err(StatsError::DimensionMismatch(format!(
195 "covariance shape ({}, {}) must be ({}, {})",
196 covariance.nrows(),
197 covariance.ncols(),
198 mean.len(),
199 mean.len()
200 )));
201 }
202
203 let precision = scirs2_linalg::inv(&covariance.view(), None).map_err(|e| {
205 StatsError::ComputationError(format!("Failed to invert covariance: {}", e))
206 })?;
207
208 Ok(Self { mean, precision })
209 }
210
211 pub fn from_precision(mean: Array1<f64>, precision: Array2<f64>) -> Result<Self> {
213 checkarray_finite(&mean, "mean")?;
214 checkarray_finite(&precision, "precision")?;
215
216 if precision.nrows() != mean.len() || precision.ncols() != mean.len() {
217 return Err(StatsError::DimensionMismatch(format!(
218 "precision shape ({}, {}) must be ({}, {})",
219 precision.nrows(),
220 precision.ncols(),
221 mean.len(),
222 mean.len()
223 )));
224 }
225
226 Ok(Self { mean, precision })
227 }
228}
229
230impl ConditionalDistribution for MultivariateNormalGibbs {
231 fn sample_conditional<R: Rng + ?Sized>(
232 &self,
233 current_state: &Array1<f64>,
234 variable_index: usize,
235 rng: &mut R,
236 ) -> Result<f64> {
237 let dim = self.mean.len();
238 if variable_index >= dim {
239 return Err(StatsError::InvalidArgument(format!(
240 "variable_index ({}) must be less than dimension ({})",
241 variable_index, dim
242 )));
243 }
244
245 let precision_ii = self.precision[[variable_index, variable_index]];
250 if precision_ii.abs() < f64::EPSILON {
251 return Err(StatsError::ComputationError(
252 "Precision matrix must have positive diagonal elements".to_string(),
253 ));
254 }
255
256 let conditional_variance = 1.0 / precision_ii;
258 let conditional_std = conditional_variance.sqrt();
259
260 let mut sum = 0.0;
262 for j in 0..dim {
263 if j != variable_index {
264 sum += self.precision[[variable_index, j]] * (current_state[j] - self.mean[j]);
265 }
266 }
267 let conditional_mean = self.mean[variable_index] - sum / precision_ii;
268
269 use scirs2_core::random::{Distribution, Normal};
271 let normal = Normal::new(conditional_mean, conditional_std).map_err(|e| {
272 StatsError::ComputationError(format!("Failed to create normal distribution: {}", e))
273 })?;
274
275 Ok(normal.sample(rng))
276 }
277
278 fn dim(&self) -> usize {
279 self.mean.len()
280 }
281
282 fn log_density(&self, x: &Array1<f64>) -> Option<f64> {
283 let diff = x - &self.mean;
284 let quad_form = diff.dot(&self.precision.dot(&diff));
285
286 let det = scirs2_linalg::det(&self.precision.view(), None).ok()?;
288 if det <= 0.0 {
289 return None;
290 }
291
292 let d = self.mean.len() as f64;
293 let log_norm_const = 0.5 * (det.ln() - d * (2.0 * std::f64::consts::PI).ln());
294
295 Some(log_norm_const - 0.5 * quad_form)
296 }
297}
298
299#[derive(Debug, Clone)]
303pub struct GaussianMixtureGibbs {
304 pub means: Array2<f64>,
306 pub precisions: Vec<Array2<f64>>,
308 pub weights: Array1<f64>,
310 pub data: Array2<f64>,
312 pub assignments: Array1<usize>,
314 pub n_components: usize,
316 pub prior_mean: Array1<f64>,
318 pub prior_precision: Array2<f64>,
319 pub prior_alpha: Array1<f64>, }
321
322impl GaussianMixtureGibbs {
323 pub fn new(
325 data: Array2<f64>,
326 n_components: usize,
327 prior_mean: Array1<f64>,
328 prior_precision: Array2<f64>,
329 prior_alpha: Array1<f64>,
330 ) -> Result<Self> {
331 checkarray_finite(&data, "data")?;
332 check_positive(n_components, "n_components")?;
333 checkarray_finite(&prior_mean, "prior_mean")?;
334 checkarray_finite(&prior_precision, "prior_precision")?;
335 checkarray_finite(&prior_alpha, "prior_alpha")?;
336
337 let (n_samples_, dim) = data.dim();
338
339 if prior_alpha.len() != n_components {
340 return Err(StatsError::DimensionMismatch(format!(
341 "prior_alpha length ({}) must equal n_components ({})",
342 prior_alpha.len(),
343 n_components
344 )));
345 }
346
347 let means = Array2::zeros((n_components, dim));
349 let precisions = vec![Array2::eye(dim); n_components];
350 let weights = Array1::from_elem(n_components, 1.0 / n_components as f64);
351 let assignments = Array1::zeros(n_samples_);
352
353 Ok(Self {
354 means,
355 precisions,
356 weights,
357 data,
358 assignments,
359 n_components,
360 prior_mean,
361 prior_precision,
362 prior_alpha,
363 })
364 }
365
366 pub fn step<R: Rng + ?Sized>(&mut self, rng: &mut R) -> Result<()> {
368 self.sample_assignments(rng)?;
370
371 self.sample_parameters(rng)?;
373
374 self.sample_weights(rng)?;
376
377 Ok(())
378 }
379
380 fn sample_assignments<R: Rng + ?Sized>(&mut self, rng: &mut R) -> Result<()> {
382 for i in 0..self.data.nrows() {
383 let data_point = self.data.row(i);
384 let mut log_probs = Array1::zeros(self.n_components);
385
386 for k in 0..self.n_components {
388 let mean_k = self.means.row(k);
389 let precision_k = &self.precisions[k];
390
391 let diff = &data_point.to_owned() - &mean_k.to_owned();
392 let quad_form = diff.dot(&precision_k.dot(&diff));
393
394 let det = scirs2_linalg::det(&precision_k.view(), None).map_err(|e| {
396 StatsError::ComputationError(format!("Failed to compute determinant: {}", e))
397 })?;
398
399 if det <= 0.0 {
400 return Err(StatsError::ComputationError(
401 "Precision matrix must be positive definite".to_string(),
402 ));
403 }
404
405 let log_likelihood = 0.5 * det.ln() - 0.5 * quad_form;
406 log_probs[k] = self.weights[k].ln() + log_likelihood;
407 }
408
409 let max_log_prob = log_probs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
411 let mut probs = log_probs.mapv(|x| (x - max_log_prob).exp());
412 let prob_sum = probs.sum();
413 probs /= prob_sum;
414
415 let u: f64 = rng.random();
417 let mut cumsum = 0.0;
418 let mut selected = 0;
419
420 for (k, &p) in probs.iter().enumerate() {
421 cumsum += p;
422 if u <= cumsum {
423 selected = k;
424 break;
425 }
426 }
427
428 self.assignments[i] = selected;
429 }
430
431 Ok(())
432 }
433
434 fn sample_parameters<R: Rng + ?Sized>(&mut self, rng: &mut R) -> Result<()> {
436 for k in 0..self.n_components {
437 let assigned_indices: Vec<usize> = self
439 .assignments
440 .iter()
441 .enumerate()
442 .filter_map(|(i, &assignment)| if assignment == k { Some(i) } else { None })
443 .collect();
444
445 if assigned_indices.is_empty() {
446 self.sample_from_prior(k, rng)?;
448 } else {
449 self.sample_posterior(k, &assigned_indices, rng)?;
451 }
452 }
453
454 Ok(())
455 }
456
457 fn sample_from_prior<R: Rng + ?Sized>(&mut self, component: usize, rng: &mut R) -> Result<()> {
459 use scirs2_core::random::{Distribution, Normal};
461
462 let dim = self.prior_mean.len();
463 let mut new_mean = Array1::zeros(dim);
464
465 for i in 0..dim {
467 let variance = 1.0 / self.prior_precision[[i, i]];
468 let std = variance.sqrt();
469 let normal = Normal::new(self.prior_mean[i], std).map_err(|e| {
470 StatsError::ComputationError(format!("Failed to create normal: {}", e))
471 })?;
472 new_mean[i] = normal.sample(rng);
473 }
474
475 self.means.row_mut(component).assign(&new_mean);
476
477 self.precisions[component] = self.prior_precision.clone();
479
480 Ok(())
481 }
482
483 fn sample_posterior<R: Rng + ?Sized>(
485 &mut self,
486 component: usize,
487 assigned_indices: &[usize],
488 rng: &mut R,
489 ) -> Result<()> {
490 let n_assigned = assigned_indices.len();
491 let dim = self.prior_mean.len();
492
493 let mut sample_mean = Array1::zeros(dim);
495 for &i in assigned_indices {
496 sample_mean = sample_mean + self.data.row(i);
497 }
498 sample_mean /= n_assigned as f64;
499
500 let posterior_precision = &self.prior_precision + Array2::eye(dim) * n_assigned as f64;
502 let posterior_mean = {
503 let prior_contrib = self.prior_precision.dot(&self.prior_mean);
504 let data_contrib = Array1::from_elem(dim, n_assigned as f64) * &sample_mean;
505 let precision_inv =
506 scirs2_linalg::inv(&posterior_precision.view(), None).map_err(|e| {
507 StatsError::ComputationError(format!("Failed to invert precision: {}", e))
508 })?;
509 precision_inv.dot(&(prior_contrib + data_contrib))
510 };
511
512 use scirs2_core::random::{Distribution, Normal};
514 let mut new_mean = Array1::zeros(dim);
515
516 for i in 0..dim {
517 let variance = 1.0 / posterior_precision[[i, i]];
518 let std = variance.sqrt();
519 let normal = Normal::new(posterior_mean[i], std).map_err(|e| {
520 StatsError::ComputationError(format!("Failed to create normal: {}", e))
521 })?;
522 new_mean[i] = normal.sample(rng);
523 }
524
525 self.means.row_mut(component).assign(&new_mean);
526
527 self.precisions[component] = posterior_precision;
529
530 Ok(())
531 }
532
533 fn sample_weights<R: Rng + ?Sized>(&mut self, rng: &mut R) -> Result<()> {
535 let mut counts = Array1::<f64>::zeros(self.n_components);
537 for &assignment in self.assignments.iter() {
538 counts[assignment] += 1.0;
539 }
540
541 let posterior_alpha = &self.prior_alpha + &counts;
543
544 use scirs2_core::random::{Distribution, Gamma};
546 let mut gamma_samples = Array1::zeros(self.n_components);
547
548 for k in 0..self.n_components {
549 let gamma = Gamma::new(posterior_alpha[k], 1.0).map_err(|e| {
550 StatsError::ComputationError(format!("Failed to create Gamma: {}", e))
551 })?;
552 gamma_samples[k] = gamma.sample(rng);
553 }
554
555 let sum = gamma_samples.sum();
557 self.weights = gamma_samples / sum;
558
559 Ok(())
560 }
561}
562
563pub struct BlockedGibbsSampler<C: ConditionalDistribution> {
567 pub sampler: GibbsSampler<C>,
569 pub blocks: Vec<Vec<usize>>,
571}
572
573impl<C: ConditionalDistribution> BlockedGibbsSampler<C> {
574 pub fn new(conditionals: C, initial: Array1<f64>, blocks: Vec<Vec<usize>>) -> Result<Self> {
576 let sampler = GibbsSampler::new(conditionals, initial)?;
577
578 let dim = sampler.conditionals.dim();
580 let mut all_indices = Vec::new();
581 for block in &blocks {
582 for &idx in block {
583 if idx >= dim {
584 return Err(StatsError::InvalidArgument(format!(
585 "Block index {} exceeds dimension {}",
586 idx, dim
587 )));
588 }
589 all_indices.push(idx);
590 }
591 }
592
593 all_indices.sort_unstable();
594 all_indices.dedup();
595 if all_indices.len() != dim {
596 return Err(StatsError::InvalidArgument(
597 "Blocks must cover all variables exactly once".to_string(),
598 ));
599 }
600
601 Ok(Self { sampler, blocks })
602 }
603
604 pub fn step<R: Rng + ?Sized>(&mut self, rng: &mut R) -> Result<Array1<f64>> {
606 for block in &self.blocks {
608 for &var_idx in block {
609 let new_value = self.sampler.conditionals.sample_conditional(
610 &self.sampler.current,
611 var_idx,
612 rng,
613 )?;
614 self.sampler.current[var_idx] = new_value;
615 }
616 }
617
618 self.sampler.n_samples_ += 1;
619 Ok(self.sampler.current.clone())
620 }
621
622 pub fn sample<R: Rng + ?Sized>(
624 &mut self,
625 n_samples_: usize,
626 rng: &mut R,
627 ) -> Result<Array2<f64>> {
628 let dim = self.sampler.current.len();
629 let mut samples = Array2::zeros((n_samples_, dim));
630
631 for i in 0..n_samples_ {
632 let sample = self.step(rng)?;
633 samples.row_mut(i).assign(&sample);
634 }
635
636 Ok(samples)
637 }
638}