// scirs2_stats/mcmc/advanced.rs
//! Advanced MCMC methods
//!
//! This module implements sophisticated MCMC algorithms including multiple-try Metropolis,
//! parallel tempering, slice sampling, and ensemble methods.
5
use std::sync::Arc;

use scirs2_core::ndarray::{Array1, Array2, ArrayView1, Axis};
use scirs2_core::validation::*;
use scirs2_core::Rng;
use statrs::statistics::Statistics;

use super::{ProposalDistribution, TargetDistribution};
use crate::error::{StatsError, StatsResult as Result};
13
/// Multiple-try Metropolis sampler
///
/// Generates multiple proposals at each step and selects one using weighted sampling,
/// which can lead to better acceptance rates and mixing.
pub struct MultipleTryMetropolis<T: TargetDistribution, P: ProposalDistribution> {
    /// Target distribution to sample from
    pub target: T,
    /// Proposal distribution used to generate candidate states
    pub proposal: P,
    /// Current state of the chain
    pub current: Array1<f64>,
    /// Cached log density of the current state (kept in sync by `step`)
    pub current_log_density: f64,
    /// Number of proposal trials per step
    pub n_tries: usize,
    /// Number of accepted proposals (diagnostics; see `acceptance_rate`)
    pub n_accepted: usize,
    /// Total number of steps taken
    pub n_steps: usize,
}
34
35impl<T: TargetDistribution, P: ProposalDistribution> MultipleTryMetropolis<T, P> {
36    /// Create a new multiple-try Metropolis sampler
37    pub fn new(target: T, proposal: P, initial: Array1<f64>, ntries: usize) -> Result<Self> {
38        checkarray_finite(&initial, "initial")?;
39        check_positive(ntries, "n_tries")?;
40
41        if initial.len() != target.dim() {
42            return Err(StatsError::DimensionMismatch(format!(
43                "initial dimension ({}) must match target dimension ({})",
44                initial.len(),
45                target.dim()
46            )));
47        }
48
49        let current_log_density = target.log_density(&initial);
50
51        Ok(Self {
52            target,
53            proposal,
54            current: initial,
55            current_log_density,
56            n_tries: ntries,
57            n_accepted: 0,
58            n_steps: 0,
59        })
60    }
61
62    /// Perform one step of multiple-try Metropolis
63    pub fn step<R: Rng + ?Sized>(&mut self, rng: &mut R) -> Result<Array1<f64>> {
64        // Generate multiple proposals
65        let mut proposals = Vec::with_capacity(self.n_tries);
66        let mut log_densities = Vec::with_capacity(self.n_tries);
67        let mut weights = Vec::with_capacity(self.n_tries);
68
69        for _ in 0..self.n_tries {
70            let proposal = self.proposal.sample(&self.current, rng);
71            let log_density = self.target.log_density(&proposal);
72            let weight = log_density.exp();
73
74            proposals.push(proposal);
75            log_densities.push(log_density);
76            weights.push(weight);
77        }
78
79        // Select proposal using weighted sampling
80        let total_weight: f64 = weights.iter().sum();
81        if total_weight <= 0.0 {
82            // All proposals have zero weight, reject
83            self.n_steps += 1;
84            return Ok(self.current.clone());
85        }
86
87        let u: f64 = rng.random();
88        let mut cumsum = 0.0;
89        let mut selected_idx = 0;
90
91        for (i, &weight) in weights.iter().enumerate() {
92            cumsum += weight / total_weight;
93            if u <= cumsum {
94                selected_idx = i;
95                break;
96            }
97        }
98
99        let selected_proposal = &proposals[selected_idx];
100        let selected_log_density = log_densities[selected_idx];
101
102        // Compute reverse proposals from selected proposal
103        let mut reverse_weights = Vec::with_capacity(self.n_tries);
104        for _ in 0..self.n_tries {
105            let reverse_proposal = self.proposal.sample(selected_proposal, rng);
106            let reverse_log_density = self.target.log_density(&reverse_proposal);
107            let reverse_weight = reverse_log_density.exp();
108            reverse_weights.push(reverse_weight);
109        }
110
111        // Include current state in reverse proposals
112        reverse_weights.push(self.current_log_density.exp());
113
114        let reverse_total_weight: f64 = reverse_weights.iter().sum();
115
116        // Compute acceptance ratio
117        let log_ratio = selected_log_density - self.current_log_density + reverse_total_weight.ln()
118            - total_weight.ln();
119
120        // Accept or reject
121        let accept_u: f64 = rng.random();
122        self.n_steps += 1;
123
124        if accept_u.ln() < log_ratio {
125            self.current = selected_proposal.clone();
126            self.current_log_density = selected_log_density;
127            self.n_accepted += 1;
128        }
129
130        Ok(self.current.clone())
131    }
132
133    /// Sample multiple states
134    pub fn sample<R: Rng + ?Sized>(
135        &mut self,
136        n_samples_: usize,
137        rng: &mut R,
138    ) -> Result<Array2<f64>> {
139        let dim = self.current.len();
140        let mut samples = Array2::zeros((n_samples_, dim));
141
142        for i in 0..n_samples_ {
143            let sample = self.step(rng)?;
144            samples.row_mut(i).assign(&sample);
145        }
146
147        Ok(samples)
148    }
149
150    /// Get acceptance rate
151    pub fn acceptance_rate(&self) -> f64 {
152        if self.n_steps == 0 {
153            0.0
154        } else {
155            self.n_accepted as f64 / self.n_steps as f64
156        }
157    }
158}
159
/// Parallel Tempering (Replica Exchange) sampler
///
/// Runs multiple chains at different temperatures in parallel and exchanges
/// states between chains to improve mixing.
pub struct ParallelTempering<
    T: TargetDistribution + Clone + Send,
    P: ProposalDistribution + Clone + Send,
> {
    /// Base (untempered) target distribution
    pub base_target: T,
    /// Proposal distribution shared by all chains
    pub proposal: P,
    /// Temperature schedule, one temperature per chain
    pub temperatures: Array1<f64>,
    /// Current states for each chain
    pub states: Vec<Array1<f64>>,
    /// Current tempered log densities for each chain: log pi(x) / T
    pub log_densities: Vec<f64>,
    /// Number of chains
    pub n_chains: usize,
    /// Exchange attempt frequency (exchanges tried every `exchange_freq` steps)
    pub exchange_freq: usize,
    /// Acceptance counters for within-chain moves
    pub move_accepted: Vec<usize>,
    /// Acceptance counters for exchanges between adjacent chains
    pub exchange_accepted: Vec<usize>,
    /// Total within-chain move attempts
    pub move_attempts: Vec<usize>,
    /// Total exchange attempts between adjacent chains
    pub exchange_attempts: Vec<usize>,
}
191
192impl<T: TargetDistribution + Clone + Send, P: ProposalDistribution + Clone + Send>
193    ParallelTempering<T, P>
194{
195    /// Create a new parallel tempering sampler
196    pub fn new(
197        base_target: T,
198        proposal: P,
199        temperatures: Array1<f64>,
200        initial_states: Vec<Array1<f64>>,
201        exchange_freq: usize,
202    ) -> Result<Self> {
203        check_positive(exchange_freq, "exchange_freq")?;
204
205        let n_chains = temperatures.len();
206        if initial_states.len() != n_chains {
207            return Err(StatsError::DimensionMismatch(format!(
208                "initial_states length ({}) must match temperatures length ({})",
209                initial_states.len(),
210                n_chains
211            )));
212        }
213
214        // Check temperatures are positive and sorted
215        for &temp in temperatures.iter() {
216            check_positive(temp, "temperature")?;
217        }
218
219        // Compute initial log densities
220        let mut log_densities = Vec::with_capacity(n_chains);
221        for (i, state) in initial_states.iter().enumerate() {
222            checkarray_finite(state, "initial_state")?;
223            let temp = temperatures[i];
224            let log_density = base_target.log_density(state) / temp;
225            log_densities.push(log_density);
226        }
227
228        Ok(Self {
229            base_target,
230            proposal,
231            states: initial_states,
232            log_densities,
233            temperatures,
234            n_chains,
235            exchange_freq,
236            move_accepted: vec![0; n_chains],
237            exchange_accepted: vec![0; n_chains - 1],
238            move_attempts: vec![0; n_chains],
239            exchange_attempts: vec![0; n_chains - 1],
240        })
241    }
242
243    /// Perform one step for all chains
244    pub fn step<R: Rng + ?Sized>(&mut self, rng: &mut R) -> Result<()> {
245        // Metropolis steps for all chains
246        for i in 0..self.n_chains {
247            let temp = self.temperatures[i];
248            let current_state = &self.states[i];
249
250            // Generate proposal
251            let proposal = self.proposal.sample(current_state, rng);
252            let proposal_log_density = self.base_target.log_density(&proposal) / temp;
253
254            // Accept or reject
255            let log_ratio = proposal_log_density - self.log_densities[i]
256                + P::log_ratio(current_state, &proposal);
257
258            self.move_attempts[i] += 1;
259            let u: f64 = rng.random();
260
261            if u.ln() < log_ratio {
262                self.states[i] = proposal;
263                self.log_densities[i] = proposal_log_density;
264                self.move_accepted[i] += 1;
265            }
266        }
267
268        Ok(())
269    }
270
271    /// Attempt exchanges between adjacent chains
272    pub fn exchange_step<R: Rng + ?Sized>(&mut self, rng: &mut R) -> Result<()> {
273        for i in 0..(self.n_chains - 1) {
274            let temp1 = self.temperatures[i];
275            let temp2 = self.temperatures[i + 1];
276
277            let log_density1 = self.log_densities[i];
278            let log_density2 = self.log_densities[i + 1];
279
280            // Compute exchange probability
281            let log_ratio = (log_density1 * temp1 - log_density2 * temp2) / temp2
282                - (log_density1 * temp1 - log_density2 * temp2) / temp1;
283
284            self.exchange_attempts[i] += 1;
285            let u: f64 = rng.random();
286
287            if u.ln() < log_ratio {
288                // Exchange states
289                self.states.swap(i, i + 1);
290
291                // Update log densities for new temperatures
292                let state1_new_log_density = self.base_target.log_density(&self.states[i]) / temp1;
293                let state2_new_log_density =
294                    self.base_target.log_density(&self.states[i + 1]) / temp2;
295
296                self.log_densities[i] = state1_new_log_density;
297                self.log_densities[i + 1] = state2_new_log_density;
298
299                self.exchange_accepted[i] += 1;
300            }
301        }
302
303        Ok(())
304    }
305
306    /// Run the parallel tempering sampler
307    pub fn sample<R: Rng + ?Sized>(
308        &mut self,
309        n_samples_: usize,
310        rng: &mut R,
311    ) -> Result<Array2<f64>> {
312        let dim = self.states[0].len();
313        let mut samples = Array2::zeros((n_samples_, dim));
314
315        for i in 0..n_samples_ {
316            self.step(rng)?;
317
318            // Attempt exchanges periodically
319            if i % self.exchange_freq == 0 {
320                self.exchange_step(rng)?;
321            }
322
323            // Store sample from coldest chain (temperature = 1.0)
324            let coldest_idx = self
325                .temperatures
326                .iter()
327                .enumerate()
328                .min_by(|(_, a), (_, b)| a.partial_cmp(b).expect("Operation failed"))
329                .map(|(idx, _)| idx)
330                .unwrap_or(0);
331
332            samples.row_mut(i).assign(&self.states[coldest_idx]);
333        }
334
335        Ok(samples)
336    }
337
338    /// Get acceptance rates for moves
339    pub fn move_acceptance_rates(&self) -> Array1<f64> {
340        let mut rates = Array1::zeros(self.n_chains);
341        for i in 0..self.n_chains {
342            if self.move_attempts[i] > 0 {
343                rates[i] = self.move_accepted[i] as f64 / self.move_attempts[i] as f64;
344            }
345        }
346        rates
347    }
348
349    /// Get acceptance rates for exchanges
350    pub fn exchange_acceptance_rates(&self) -> Array1<f64> {
351        let mut rates = Array1::zeros(self.n_chains - 1);
352        for i in 0..(self.n_chains - 1) {
353            if self.exchange_attempts[i] > 0 {
354                rates[i] = self.exchange_accepted[i] as f64 / self.exchange_attempts[i] as f64;
355            }
356        }
357        rates
358    }
359}
360
/// Slice sampler
///
/// Uses auxiliary variables to transform the sampling problem into uniform sampling
/// over the area under the probability density function.
pub struct SliceSampler<T: TargetDistribution> {
    /// Target distribution to sample from
    pub target: T,
    /// Current state of the chain
    pub current: Array1<f64>,
    /// Cached log density of the current state
    pub current_log_density: f64,
    /// Step size for the initial slice interval
    pub stepsize: f64,
    /// Maximum number of doublings when expanding the interval
    pub max_doublings: usize,
    /// Number of accepted proposals (always equals `n_proposed` — slice
    /// sampling accepts every move)
    pub n_accepted: usize,
    /// Total number of proposals (steps taken)
    pub n_proposed: usize,
}
381
impl<T: TargetDistribution> SliceSampler<T> {
    /// Create a new slice sampler.
    ///
    /// # Errors
    /// Fails if `initial` contains non-finite values, `stepsize` is not
    /// positive, or the initial state's length does not match `target.dim()`.
    pub fn new(target: T, initial: Array1<f64>, stepsize: f64) -> Result<Self> {
        checkarray_finite(&initial, "initial")?;
        check_positive(stepsize, "stepsize")?;

        if initial.len() != target.dim() {
            return Err(StatsError::DimensionMismatch(format!(
                "initial dimension ({}) must match target dimension ({})",
                initial.len(),
                target.dim()
            )));
        }

        let current_log_density = target.log_density(&initial);

        Ok(Self {
            target,
            current: initial,
            current_log_density,
            stepsize,
            // Cap on interval expansion iterations (see slice_sample_dimension).
            max_doublings: 20,
            n_accepted: 0,
            n_proposed: 0,
        })
    }

    /// Perform one step of slice sampling.
    ///
    /// Updates each coordinate in turn with a univariate slice-sampling move
    /// (a coordinate-wise scan), then refreshes the cached log density.
    /// Slice sampling has no reject step, so every call advances the chain.
    pub fn step<R: Rng + ?Sized>(&mut self, rng: &mut R) -> Result<Array1<f64>> {
        let dim = self.current.len();
        let mut new_state = self.current.clone();

        // Sample each dimension sequentially; later coordinates see the
        // already-updated values of earlier ones.
        for d in 0..dim {
            new_state[d] = self.slice_sample_dimension(&new_state, d, rng)?;
        }

        self.current = new_state;
        self.current_log_density = self.target.log_density(&self.current);
        self.n_proposed += 1;
        self.n_accepted += 1; // Slice sampling always accepts

        Ok(self.current.clone())
    }

    /// Sample a single coordinate using slice sampling.
    ///
    /// Draws a slice level u ~ U(0, pi(x)) in log space, expands an interval
    /// around the current value until both endpoints fall below the slice
    /// (doubling procedure), then samples uniformly from the interval,
    /// shrinking it toward the current value after each rejected draw.
    fn slice_sample_dimension<R: Rng + ?Sized>(
        &self,
        state: &Array1<f64>,
        dimension: usize,
        rng: &mut R,
    ) -> Result<f64> {
        let current_value = state[dimension];
        let current_log_density = self.target.log_density(state);

        // Sample auxiliary variable (slice level): log u + log pi(x).
        let u: f64 = rng.random();
        let slice_level = current_log_density + u.ln();

        // Initial interval of width `stepsize`, positioned uniformly at random
        // around the current value.
        let mut left = current_value - self.stepsize * rng.random::<f64>();
        let mut right = left + self.stepsize;

        // Expand the interval by doubling a randomly chosen side until both
        // endpoints lie outside the slice (or the doubling budget is spent).
        for _ in 0..self.max_doublings {
            let mut left_state = state.clone();
            left_state[dimension] = left;
            let left_log_density = self.target.log_density(&left_state);

            let mut right_state = state.clone();
            right_state[dimension] = right;
            let right_log_density = self.target.log_density(&right_state);

            if left_log_density <= slice_level && right_log_density <= slice_level {
                break;
            }

            // Double the interval on one side, chosen by a fair coin.
            if rng.random::<bool>() {
                left = left - (right - left);
            } else {
                right = right + (right - left);
            }
        }

        // Sample from the interval, shrinking it after each rejection so the
        // loop terminates (the current value is always inside the slice).
        loop {
            let proposal = left + (right - left) * rng.random::<f64>();
            let mut proposal_state = state.clone();
            proposal_state[dimension] = proposal;
            let proposal_log_density = self.target.log_density(&proposal_state);

            if proposal_log_density > slice_level {
                return Ok(proposal);
            }

            // Shrink toward the current value, keeping it inside the interval.
            if proposal < current_value {
                left = proposal;
            } else {
                right = proposal;
            }

            // Prevent infinite loop if the interval collapses numerically.
            if (right - left).abs() < 1e-10 {
                return Ok(current_value);
            }
        }
    }

    /// Run the sampler for `n_samples_` steps, returning one row per step.
    pub fn sample<R: Rng + ?Sized>(
        &mut self,
        n_samples_: usize,
        rng: &mut R,
    ) -> Result<Array2<f64>> {
        let dim = self.current.len();
        let mut samples = Array2::zeros((n_samples_, dim));

        for i in 0..n_samples_ {
            let sample = self.step(rng)?;
            samples.row_mut(i).assign(&sample);
        }

        Ok(samples)
    }

    /// Get acceptance rate (always 1.0 — slice sampling has no reject step).
    pub fn acceptance_rate(&self) -> f64 {
        1.0
    }
}
513
/// Ensemble sampler (Affine Invariant MCMC)
///
/// Uses an ensemble of walkers that evolve simultaneously, with proposals
/// based on the current positions of other walkers.
pub struct EnsembleSampler<T: TargetDistribution + Clone + Send + Sync> {
    /// Target distribution (shared; `Arc` allows cheap cloning across threads)
    pub target: Arc<T>,
    /// Walker positions, one row per walker: shape (n_walkers, dim)
    pub walkers: Array2<f64>,
    /// Cached log densities, one per walker
    pub log_densities: Array1<f64>,
    /// Number of walkers
    pub n_walkers: usize,
    /// Dimensionality of the state space
    pub dim: usize,
    /// Scale parameter `a` of the stretch move (default 2.0)
    pub scale: f64,
    /// Per-walker acceptance counters
    pub n_accepted: Array1<usize>,
    /// Per-walker proposal counters
    pub n_proposed: Array1<usize>,
}
536
537impl<T: TargetDistribution + Clone + Send + Sync> EnsembleSampler<T> {
538    /// Create a new ensemble sampler
539    pub fn new(target: T, initialwalkers: Array2<f64>, scale: Option<f64>) -> Result<Self> {
540        checkarray_finite(&initialwalkers, "initial_walkers")?;
541        let (n_walkers, dim) = initialwalkers.dim();
542        let scale = scale.unwrap_or(2.0);
543
544        if n_walkers < 2 * dim {
545            return Err(StatsError::InvalidArgument(format!(
546                "Number of walkers ({}) should be at least 2 * dim ({})",
547                n_walkers,
548                2 * dim
549            )));
550        }
551
552        check_positive(scale, "scale")?;
553
554        // Compute initial log densities
555        let mut log_densities = Array1::zeros(n_walkers);
556        for i in 0..n_walkers {
557            let walker = initialwalkers.row(i);
558            log_densities[i] = target.log_density(&walker.to_owned());
559        }
560
561        Ok(Self {
562            target: Arc::new(target),
563            walkers: initialwalkers,
564            log_densities,
565            n_walkers,
566            dim,
567            scale,
568            n_accepted: Array1::zeros(n_walkers),
569            n_proposed: Array1::zeros(n_walkers),
570        })
571    }
572
573    /// Perform one step for all walkers
574    pub fn step<R: Rng + ?Sized>(&mut self, rng: &mut R) -> Result<()> {
575        // Split walkers into two groups
576        let n_half = self.n_walkers / 2;
577
578        // Update first half using second half as complementary ensemble
579        self.update_group(0, n_half, n_half, self.n_walkers, rng)?;
580
581        // Update second half using first half as complementary ensemble
582        self.update_group(n_half, self.n_walkers, 0, n_half, rng)?;
583
584        Ok(())
585    }
586
587    /// Update a group of walkers
588    fn update_group<R: Rng + ?Sized>(
589        &mut self,
590        start: usize,
591        end: usize,
592        comp_start: usize,
593        comp_end: usize,
594        rng: &mut R,
595    ) -> Result<()> {
596        for i in start..end {
597            // Select random walker from complementary ensemble
598            let compsize = comp_end - comp_start;
599            let j = comp_start + rng.random_range(0..compsize);
600
601            // Generate stretch parameter
602            let z = ((self.scale - 1.0) * rng.random::<f64>() + 1.0).powf(2.0) / self.scale;
603
604            // Compute proposal
605            let walker_i = self.walkers.row(i);
606            let walker_j = self.walkers.row(j);
607            let proposal = &walker_j.to_owned() + z * (&walker_i.to_owned() - &walker_j.to_owned());
608
609            // Compute log density
610            let proposal_log_density = self.target.log_density(&proposal);
611
612            // Compute acceptance probability
613            let log_ratio =
614                (self.dim as f64 - 1.0) * z.ln() + proposal_log_density - self.log_densities[i];
615
616            // Accept or reject
617            let u: f64 = rng.random();
618            self.n_proposed[i] += 1;
619
620            if u.ln() < log_ratio {
621                self.walkers.row_mut(i).assign(&proposal);
622                self.log_densities[i] = proposal_log_density;
623                self.n_accepted[i] += 1;
624            }
625        }
626
627        Ok(())
628    }
629
630    /// Sample multiple steps
631    pub fn sample<R: Rng + ?Sized>(
632        &mut self,
633        n_samples_: usize,
634        rng: &mut R,
635    ) -> Result<Array2<f64>> {
636        let total_samples = n_samples_ * self.n_walkers;
637        let mut samples = Array2::zeros((total_samples, self.dim));
638
639        for i in 0..n_samples_ {
640            self.step(rng)?;
641
642            // Store all walker positions
643            for j in 0..self.n_walkers {
644                let sample_idx = i * self.n_walkers + j;
645                samples.row_mut(sample_idx).assign(&self.walkers.row(j));
646            }
647        }
648
649        Ok(samples)
650    }
651
652    /// Get acceptance rates for all walkers
653    pub fn acceptance_rates(&self) -> Array1<f64> {
654        let mut rates = Array1::zeros(self.n_walkers);
655        for i in 0..self.n_walkers {
656            if self.n_proposed[i] > 0 {
657                rates[i] = self.n_accepted[i] as f64 / self.n_proposed[i] as f64;
658            }
659        }
660        rates
661    }
662
663    /// Get current walker positions
664    pub fn get_walkers(&self) -> &Array2<f64> {
665        &self.walkers
666    }
667
668    /// Compute chain statistics (mean, autocorrelation time, etc.)
669    pub fn chain_statistics(&self, samples: &Array2<f64>) -> Result<ChainStatistics> {
670        let (n_samples_, dim) = samples.dim();
671
672        // Compute means
673        let means = samples.mean_axis(Axis(0)).expect("Operation failed");
674
675        // Compute variances
676        let mut variances = Array1::zeros(dim);
677        for j in 0..dim {
678            let col = samples.column(j);
679            let mean_j = means[j];
680            let var_j = col.mapv(|x| (x - mean_j).powi(2)).mean();
681            variances[j] = var_j;
682        }
683
684        // Estimate autocorrelation times (simplified)
685        let mut autocorr_times = Array1::zeros(dim);
686        for j in 0..dim {
687            autocorr_times[j] = self.estimate_autocorr_time(&samples.column(j))?;
688        }
689
690        Ok(ChainStatistics {
691            means,
692            variances,
693            autocorr_times,
694            n_samples_,
695            dim,
696        })
697    }
698
699    /// Estimate autocorrelation time for a single chain
700    fn estimate_autocorr_time(&self, chain: &ArrayView1<f64>) -> Result<f64> {
701        let n = chain.len();
702        if n < 4 {
703            return Ok(1.0);
704        }
705
706        use scirs2_core::ndarray::ArrayStatCompat;
707        let mean = chain.mean_or(0.0);
708        let variance = chain.mapv(|x| (x - mean).powi(2)).mean_or(1.0);
709
710        if variance <= 0.0 {
711            return Ok(1.0);
712        }
713
714        // Compute autocorrelation function
715        let max_lag = (n / 4).min(200);
716        let mut autocorr = Array1::zeros(max_lag);
717
718        for lag in 0..max_lag {
719            let mut sum = 0.0;
720            let mut count = 0;
721
722            for i in 0..(n - lag) {
723                sum += (chain[i] - mean) * (chain[i + lag] - mean);
724                count += 1;
725            }
726
727            if count > 0 {
728                autocorr[lag] = sum / (count as f64 * variance);
729            }
730        }
731
732        // Find first negative value or when autocorr drops below e^(-1)
733        let threshold = std::f64::consts::E.recip();
734        for lag in 1..max_lag {
735            if autocorr[lag] < threshold || autocorr[lag] < 0.0 {
736                return Ok(lag as f64);
737            }
738        }
739
740        Ok(max_lag as f64)
741    }
742}
743
/// Chain statistics from ensemble sampling
#[derive(Debug, Clone)]
pub struct ChainStatistics {
    /// Mean values for each dimension
    pub means: Array1<f64>,
    /// Variances for each dimension
    pub variances: Array1<f64>,
    /// Estimated autocorrelation times for each dimension
    pub autocorr_times: Array1<f64>,
    /// Number of samples the statistics were computed from
    pub n_samples_: usize,
    /// Dimensionality of the chain
    pub dim: usize,
}
758
759impl ChainStatistics {
760    /// Get effective sample sizes
761    pub fn effective_samplesizes(&self) -> Array1<f64> {
762        self.autocorr_times.mapv(|tau| {
763            if tau > 0.0 {
764                self.n_samples_ as f64 / (2.0 * tau)
765            } else {
766                self.n_samples_ as f64
767            }
768        })
769    }
770
771    /// Check if chains have converged (simplified Gelman-Rubin diagnostic)
772    pub fn is_converged(&self, threshold: f64) -> bool {
773        // For ensemble methods, check if autocorrelation times are reasonable
774        let max_autocorr = self.autocorr_times.iter().cloned().fold(0.0f64, f64::max);
775        let min_eff_samples = self.n_samples_ as f64 / (2.0 * max_autocorr);
776
777        min_eff_samples > threshold
778    }
779}