1use super::{ProposalDistribution, TargetDistribution};
7use crate::error::{StatsError, StatsResult as Result};
8use scirs2_core::ndarray::{Array1, Array2, ArrayView1, Axis};
9use scirs2_core::validation::*;
10use scirs2_core::Rng;
11use statrs::statistics::Statistics;
12use std::sync::Arc;
13
14pub struct MultipleTryMetropolis<T: TargetDistribution, P: ProposalDistribution> {
19 pub target: T,
21 pub proposal: P,
23 pub current: Array1<f64>,
25 pub current_log_density: f64,
27 pub n_tries: usize,
29 pub n_accepted: usize,
31 pub n_steps: usize,
33}
34
35impl<T: TargetDistribution, P: ProposalDistribution> MultipleTryMetropolis<T, P> {
36 pub fn new(target: T, proposal: P, initial: Array1<f64>, ntries: usize) -> Result<Self> {
38 checkarray_finite(&initial, "initial")?;
39 check_positive(ntries, "n_tries")?;
40
41 if initial.len() != target.dim() {
42 return Err(StatsError::DimensionMismatch(format!(
43 "initial dimension ({}) must match target dimension ({})",
44 initial.len(),
45 target.dim()
46 )));
47 }
48
49 let current_log_density = target.log_density(&initial);
50
51 Ok(Self {
52 target,
53 proposal,
54 current: initial,
55 current_log_density,
56 n_tries: ntries,
57 n_accepted: 0,
58 n_steps: 0,
59 })
60 }
61
62 pub fn step<R: Rng + ?Sized>(&mut self, rng: &mut R) -> Result<Array1<f64>> {
64 let mut proposals = Vec::with_capacity(self.n_tries);
66 let mut log_densities = Vec::with_capacity(self.n_tries);
67 let mut weights = Vec::with_capacity(self.n_tries);
68
69 for _ in 0..self.n_tries {
70 let proposal = self.proposal.sample(&self.current, rng);
71 let log_density = self.target.log_density(&proposal);
72 let weight = log_density.exp();
73
74 proposals.push(proposal);
75 log_densities.push(log_density);
76 weights.push(weight);
77 }
78
79 let total_weight: f64 = weights.iter().sum();
81 if total_weight <= 0.0 {
82 self.n_steps += 1;
84 return Ok(self.current.clone());
85 }
86
87 let u: f64 = rng.random();
88 let mut cumsum = 0.0;
89 let mut selected_idx = 0;
90
91 for (i, &weight) in weights.iter().enumerate() {
92 cumsum += weight / total_weight;
93 if u <= cumsum {
94 selected_idx = i;
95 break;
96 }
97 }
98
99 let selected_proposal = &proposals[selected_idx];
100 let selected_log_density = log_densities[selected_idx];
101
102 let mut reverse_weights = Vec::with_capacity(self.n_tries);
104 for _ in 0..self.n_tries {
105 let reverse_proposal = self.proposal.sample(selected_proposal, rng);
106 let reverse_log_density = self.target.log_density(&reverse_proposal);
107 let reverse_weight = reverse_log_density.exp();
108 reverse_weights.push(reverse_weight);
109 }
110
111 reverse_weights.push(self.current_log_density.exp());
113
114 let reverse_total_weight: f64 = reverse_weights.iter().sum();
115
116 let log_ratio = selected_log_density - self.current_log_density + reverse_total_weight.ln()
118 - total_weight.ln();
119
120 let accept_u: f64 = rng.random();
122 self.n_steps += 1;
123
124 if accept_u.ln() < log_ratio {
125 self.current = selected_proposal.clone();
126 self.current_log_density = selected_log_density;
127 self.n_accepted += 1;
128 }
129
130 Ok(self.current.clone())
131 }
132
133 pub fn sample<R: Rng + ?Sized>(
135 &mut self,
136 n_samples_: usize,
137 rng: &mut R,
138 ) -> Result<Array2<f64>> {
139 let dim = self.current.len();
140 let mut samples = Array2::zeros((n_samples_, dim));
141
142 for i in 0..n_samples_ {
143 let sample = self.step(rng)?;
144 samples.row_mut(i).assign(&sample);
145 }
146
147 Ok(samples)
148 }
149
150 pub fn acceptance_rate(&self) -> f64 {
152 if self.n_steps == 0 {
153 0.0
154 } else {
155 self.n_accepted as f64 / self.n_steps as f64
156 }
157 }
158}
159
160pub struct ParallelTempering<
165 T: TargetDistribution + Clone + Send,
166 P: ProposalDistribution + Clone + Send,
167> {
168 pub base_target: T,
170 pub proposal: P,
172 pub temperatures: Array1<f64>,
174 pub states: Vec<Array1<f64>>,
176 pub log_densities: Vec<f64>,
178 pub n_chains: usize,
180 pub exchange_freq: usize,
182 pub move_accepted: Vec<usize>,
184 pub exchange_accepted: Vec<usize>,
186 pub move_attempts: Vec<usize>,
188 pub exchange_attempts: Vec<usize>,
190}
191
192impl<T: TargetDistribution + Clone + Send, P: ProposalDistribution + Clone + Send>
193 ParallelTempering<T, P>
194{
195 pub fn new(
197 base_target: T,
198 proposal: P,
199 temperatures: Array1<f64>,
200 initial_states: Vec<Array1<f64>>,
201 exchange_freq: usize,
202 ) -> Result<Self> {
203 check_positive(exchange_freq, "exchange_freq")?;
204
205 let n_chains = temperatures.len();
206 if initial_states.len() != n_chains {
207 return Err(StatsError::DimensionMismatch(format!(
208 "initial_states length ({}) must match temperatures length ({})",
209 initial_states.len(),
210 n_chains
211 )));
212 }
213
214 for &temp in temperatures.iter() {
216 check_positive(temp, "temperature")?;
217 }
218
219 let mut log_densities = Vec::with_capacity(n_chains);
221 for (i, state) in initial_states.iter().enumerate() {
222 checkarray_finite(state, "initial_state")?;
223 let temp = temperatures[i];
224 let log_density = base_target.log_density(state) / temp;
225 log_densities.push(log_density);
226 }
227
228 Ok(Self {
229 base_target,
230 proposal,
231 states: initial_states,
232 log_densities,
233 temperatures,
234 n_chains,
235 exchange_freq,
236 move_accepted: vec![0; n_chains],
237 exchange_accepted: vec![0; n_chains - 1],
238 move_attempts: vec![0; n_chains],
239 exchange_attempts: vec![0; n_chains - 1],
240 })
241 }
242
243 pub fn step<R: Rng + ?Sized>(&mut self, rng: &mut R) -> Result<()> {
245 for i in 0..self.n_chains {
247 let temp = self.temperatures[i];
248 let current_state = &self.states[i];
249
250 let proposal = self.proposal.sample(current_state, rng);
252 let proposal_log_density = self.base_target.log_density(&proposal) / temp;
253
254 let log_ratio = proposal_log_density - self.log_densities[i]
256 + P::log_ratio(current_state, &proposal);
257
258 self.move_attempts[i] += 1;
259 let u: f64 = rng.random();
260
261 if u.ln() < log_ratio {
262 self.states[i] = proposal;
263 self.log_densities[i] = proposal_log_density;
264 self.move_accepted[i] += 1;
265 }
266 }
267
268 Ok(())
269 }
270
271 pub fn exchange_step<R: Rng + ?Sized>(&mut self, rng: &mut R) -> Result<()> {
273 for i in 0..(self.n_chains - 1) {
274 let temp1 = self.temperatures[i];
275 let temp2 = self.temperatures[i + 1];
276
277 let log_density1 = self.log_densities[i];
278 let log_density2 = self.log_densities[i + 1];
279
280 let log_ratio = (log_density1 * temp1 - log_density2 * temp2) / temp2
282 - (log_density1 * temp1 - log_density2 * temp2) / temp1;
283
284 self.exchange_attempts[i] += 1;
285 let u: f64 = rng.random();
286
287 if u.ln() < log_ratio {
288 self.states.swap(i, i + 1);
290
291 let state1_new_log_density = self.base_target.log_density(&self.states[i]) / temp1;
293 let state2_new_log_density =
294 self.base_target.log_density(&self.states[i + 1]) / temp2;
295
296 self.log_densities[i] = state1_new_log_density;
297 self.log_densities[i + 1] = state2_new_log_density;
298
299 self.exchange_accepted[i] += 1;
300 }
301 }
302
303 Ok(())
304 }
305
306 pub fn sample<R: Rng + ?Sized>(
308 &mut self,
309 n_samples_: usize,
310 rng: &mut R,
311 ) -> Result<Array2<f64>> {
312 let dim = self.states[0].len();
313 let mut samples = Array2::zeros((n_samples_, dim));
314
315 for i in 0..n_samples_ {
316 self.step(rng)?;
317
318 if i % self.exchange_freq == 0 {
320 self.exchange_step(rng)?;
321 }
322
323 let coldest_idx = self
325 .temperatures
326 .iter()
327 .enumerate()
328 .min_by(|(_, a), (_, b)| a.partial_cmp(b).expect("Operation failed"))
329 .map(|(idx, _)| idx)
330 .unwrap_or(0);
331
332 samples.row_mut(i).assign(&self.states[coldest_idx]);
333 }
334
335 Ok(samples)
336 }
337
338 pub fn move_acceptance_rates(&self) -> Array1<f64> {
340 let mut rates = Array1::zeros(self.n_chains);
341 for i in 0..self.n_chains {
342 if self.move_attempts[i] > 0 {
343 rates[i] = self.move_accepted[i] as f64 / self.move_attempts[i] as f64;
344 }
345 }
346 rates
347 }
348
349 pub fn exchange_acceptance_rates(&self) -> Array1<f64> {
351 let mut rates = Array1::zeros(self.n_chains - 1);
352 for i in 0..(self.n_chains - 1) {
353 if self.exchange_attempts[i] > 0 {
354 rates[i] = self.exchange_accepted[i] as f64 / self.exchange_attempts[i] as f64;
355 }
356 }
357 rates
358 }
359}
360
361pub struct SliceSampler<T: TargetDistribution> {
366 pub target: T,
368 pub current: Array1<f64>,
370 pub current_log_density: f64,
372 pub stepsize: f64,
374 pub max_doublings: usize,
376 pub n_accepted: usize,
378 pub n_proposed: usize,
380}
381
382impl<T: TargetDistribution> SliceSampler<T> {
383 pub fn new(target: T, initial: Array1<f64>, stepsize: f64) -> Result<Self> {
385 checkarray_finite(&initial, "initial")?;
386 check_positive(stepsize, "stepsize")?;
387
388 if initial.len() != target.dim() {
389 return Err(StatsError::DimensionMismatch(format!(
390 "initial dimension ({}) must match target dimension ({})",
391 initial.len(),
392 target.dim()
393 )));
394 }
395
396 let current_log_density = target.log_density(&initial);
397
398 Ok(Self {
399 target,
400 current: initial,
401 current_log_density,
402 stepsize,
403 max_doublings: 20,
404 n_accepted: 0,
405 n_proposed: 0,
406 })
407 }
408
409 pub fn step<R: Rng + ?Sized>(&mut self, rng: &mut R) -> Result<Array1<f64>> {
411 let dim = self.current.len();
412 let mut new_state = self.current.clone();
413
414 for d in 0..dim {
416 new_state[d] = self.slice_sample_dimension(&new_state, d, rng)?;
417 }
418
419 self.current = new_state;
420 self.current_log_density = self.target.log_density(&self.current);
421 self.n_proposed += 1;
422 self.n_accepted += 1; Ok(self.current.clone())
425 }
426
427 fn slice_sample_dimension<R: Rng + ?Sized>(
429 &self,
430 state: &Array1<f64>,
431 dimension: usize,
432 rng: &mut R,
433 ) -> Result<f64> {
434 let current_value = state[dimension];
435 let current_log_density = self.target.log_density(state);
436
437 let u: f64 = rng.random();
439 let slice_level = current_log_density + u.ln();
440
441 let mut left = current_value - self.stepsize * rng.random::<f64>();
443 let mut right = left + self.stepsize;
444
445 for _ in 0..self.max_doublings {
447 let mut left_state = state.clone();
448 left_state[dimension] = left;
449 let left_log_density = self.target.log_density(&left_state);
450
451 let mut right_state = state.clone();
452 right_state[dimension] = right;
453 let right_log_density = self.target.log_density(&right_state);
454
455 if left_log_density <= slice_level && right_log_density <= slice_level {
456 break;
457 }
458
459 if rng.random::<bool>() {
460 left = left - (right - left);
461 } else {
462 right = right + (right - left);
463 }
464 }
465
466 loop {
468 let proposal = left + (right - left) * rng.random::<f64>();
469 let mut proposal_state = state.clone();
470 proposal_state[dimension] = proposal;
471 let proposal_log_density = self.target.log_density(&proposal_state);
472
473 if proposal_log_density > slice_level {
474 return Ok(proposal);
475 }
476
477 if proposal < current_value {
479 left = proposal;
480 } else {
481 right = proposal;
482 }
483
484 if (right - left).abs() < 1e-10 {
486 return Ok(current_value);
487 }
488 }
489 }
490
491 pub fn sample<R: Rng + ?Sized>(
493 &mut self,
494 n_samples_: usize,
495 rng: &mut R,
496 ) -> Result<Array2<f64>> {
497 let dim = self.current.len();
498 let mut samples = Array2::zeros((n_samples_, dim));
499
500 for i in 0..n_samples_ {
501 let sample = self.step(rng)?;
502 samples.row_mut(i).assign(&sample);
503 }
504
505 Ok(samples)
506 }
507
508 pub fn acceptance_rate(&self) -> f64 {
510 1.0
511 }
512}
513
514pub struct EnsembleSampler<T: TargetDistribution + Clone + Send + Sync> {
519 pub target: Arc<T>,
521 pub walkers: Array2<f64>,
523 pub log_densities: Array1<f64>,
525 pub n_walkers: usize,
527 pub dim: usize,
529 pub scale: f64,
531 pub n_accepted: Array1<usize>,
533 pub n_proposed: Array1<usize>,
535}
536
537impl<T: TargetDistribution + Clone + Send + Sync> EnsembleSampler<T> {
538 pub fn new(target: T, initialwalkers: Array2<f64>, scale: Option<f64>) -> Result<Self> {
540 checkarray_finite(&initialwalkers, "initial_walkers")?;
541 let (n_walkers, dim) = initialwalkers.dim();
542 let scale = scale.unwrap_or(2.0);
543
544 if n_walkers < 2 * dim {
545 return Err(StatsError::InvalidArgument(format!(
546 "Number of walkers ({}) should be at least 2 * dim ({})",
547 n_walkers,
548 2 * dim
549 )));
550 }
551
552 check_positive(scale, "scale")?;
553
554 let mut log_densities = Array1::zeros(n_walkers);
556 for i in 0..n_walkers {
557 let walker = initialwalkers.row(i);
558 log_densities[i] = target.log_density(&walker.to_owned());
559 }
560
561 Ok(Self {
562 target: Arc::new(target),
563 walkers: initialwalkers,
564 log_densities,
565 n_walkers,
566 dim,
567 scale,
568 n_accepted: Array1::zeros(n_walkers),
569 n_proposed: Array1::zeros(n_walkers),
570 })
571 }
572
573 pub fn step<R: Rng + ?Sized>(&mut self, rng: &mut R) -> Result<()> {
575 let n_half = self.n_walkers / 2;
577
578 self.update_group(0, n_half, n_half, self.n_walkers, rng)?;
580
581 self.update_group(n_half, self.n_walkers, 0, n_half, rng)?;
583
584 Ok(())
585 }
586
587 fn update_group<R: Rng + ?Sized>(
589 &mut self,
590 start: usize,
591 end: usize,
592 comp_start: usize,
593 comp_end: usize,
594 rng: &mut R,
595 ) -> Result<()> {
596 for i in start..end {
597 let compsize = comp_end - comp_start;
599 let j = comp_start + rng.random_range(0..compsize);
600
601 let z = ((self.scale - 1.0) * rng.random::<f64>() + 1.0).powf(2.0) / self.scale;
603
604 let walker_i = self.walkers.row(i);
606 let walker_j = self.walkers.row(j);
607 let proposal = &walker_j.to_owned() + z * (&walker_i.to_owned() - &walker_j.to_owned());
608
609 let proposal_log_density = self.target.log_density(&proposal);
611
612 let log_ratio =
614 (self.dim as f64 - 1.0) * z.ln() + proposal_log_density - self.log_densities[i];
615
616 let u: f64 = rng.random();
618 self.n_proposed[i] += 1;
619
620 if u.ln() < log_ratio {
621 self.walkers.row_mut(i).assign(&proposal);
622 self.log_densities[i] = proposal_log_density;
623 self.n_accepted[i] += 1;
624 }
625 }
626
627 Ok(())
628 }
629
630 pub fn sample<R: Rng + ?Sized>(
632 &mut self,
633 n_samples_: usize,
634 rng: &mut R,
635 ) -> Result<Array2<f64>> {
636 let total_samples = n_samples_ * self.n_walkers;
637 let mut samples = Array2::zeros((total_samples, self.dim));
638
639 for i in 0..n_samples_ {
640 self.step(rng)?;
641
642 for j in 0..self.n_walkers {
644 let sample_idx = i * self.n_walkers + j;
645 samples.row_mut(sample_idx).assign(&self.walkers.row(j));
646 }
647 }
648
649 Ok(samples)
650 }
651
652 pub fn acceptance_rates(&self) -> Array1<f64> {
654 let mut rates = Array1::zeros(self.n_walkers);
655 for i in 0..self.n_walkers {
656 if self.n_proposed[i] > 0 {
657 rates[i] = self.n_accepted[i] as f64 / self.n_proposed[i] as f64;
658 }
659 }
660 rates
661 }
662
663 pub fn get_walkers(&self) -> &Array2<f64> {
665 &self.walkers
666 }
667
668 pub fn chain_statistics(&self, samples: &Array2<f64>) -> Result<ChainStatistics> {
670 let (n_samples_, dim) = samples.dim();
671
672 let means = samples.mean_axis(Axis(0)).expect("Operation failed");
674
675 let mut variances = Array1::zeros(dim);
677 for j in 0..dim {
678 let col = samples.column(j);
679 let mean_j = means[j];
680 let var_j = col.mapv(|x| (x - mean_j).powi(2)).mean();
681 variances[j] = var_j;
682 }
683
684 let mut autocorr_times = Array1::zeros(dim);
686 for j in 0..dim {
687 autocorr_times[j] = self.estimate_autocorr_time(&samples.column(j))?;
688 }
689
690 Ok(ChainStatistics {
691 means,
692 variances,
693 autocorr_times,
694 n_samples_,
695 dim,
696 })
697 }
698
699 fn estimate_autocorr_time(&self, chain: &ArrayView1<f64>) -> Result<f64> {
701 let n = chain.len();
702 if n < 4 {
703 return Ok(1.0);
704 }
705
706 use scirs2_core::ndarray::ArrayStatCompat;
707 let mean = chain.mean_or(0.0);
708 let variance = chain.mapv(|x| (x - mean).powi(2)).mean_or(1.0);
709
710 if variance <= 0.0 {
711 return Ok(1.0);
712 }
713
714 let max_lag = (n / 4).min(200);
716 let mut autocorr = Array1::zeros(max_lag);
717
718 for lag in 0..max_lag {
719 let mut sum = 0.0;
720 let mut count = 0;
721
722 for i in 0..(n - lag) {
723 sum += (chain[i] - mean) * (chain[i + lag] - mean);
724 count += 1;
725 }
726
727 if count > 0 {
728 autocorr[lag] = sum / (count as f64 * variance);
729 }
730 }
731
732 let threshold = std::f64::consts::E.recip();
734 for lag in 1..max_lag {
735 if autocorr[lag] < threshold || autocorr[lag] < 0.0 {
736 return Ok(lag as f64);
737 }
738 }
739
740 Ok(max_lag as f64)
741 }
742}
743
744#[derive(Debug, Clone)]
746pub struct ChainStatistics {
747 pub means: Array1<f64>,
749 pub variances: Array1<f64>,
751 pub autocorr_times: Array1<f64>,
753 pub n_samples_: usize,
755 pub dim: usize,
757}
758
759impl ChainStatistics {
760 pub fn effective_samplesizes(&self) -> Array1<f64> {
762 self.autocorr_times.mapv(|tau| {
763 if tau > 0.0 {
764 self.n_samples_ as f64 / (2.0 * tau)
765 } else {
766 self.n_samples_ as f64
767 }
768 })
769 }
770
771 pub fn is_converged(&self, threshold: f64) -> bool {
773 let max_autocorr = self.autocorr_times.iter().cloned().fold(0.0f64, f64::max);
775 let min_eff_samples = self.n_samples_ as f64 / (2.0 * max_autocorr);
776
777 min_eff_samples > threshold
778 }
779}