Skip to main content

scirs2_stats/bayesian/
advanced_mcmc.rs

1//! Advanced MCMC sampling methods
2//!
3//! This module provides sophisticated Markov Chain Monte Carlo methods including
4//! Hamiltonian Monte Carlo, No-U-Turn Sampler (NUTS), and adaptive algorithms.
5
6use crate::error::{StatsError, StatsResult as Result};
7use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
8use scirs2_core::{random::prelude::*, validation::*};
9
/// Trait for defining log probability density functions
///
/// The samplers in this module are generic over this trait: they call
/// [`LogDensity::log_density`] to score states, [`LogDensity::gradient`] to
/// drive gradient-based proposals, and [`LogDensity::ndim`] to size their
/// output buffers.
pub trait LogDensity {
    /// Compute log probability density at given point
    ///
    /// `theta` is the point at which the density is evaluated.
    fn log_density(&self, theta: &ArrayView1<f64>) -> Result<f64>;

    /// Compute gradient of log density (if available)
    ///
    /// Returns `Ok(None)` when the implementor cannot supply a gradient.
    /// `HamiltonianMonteCarlo` reports an error for such targets because its
    /// leapfrog integrator needs the gradient.
    fn gradient(&self, theta: &ArrayView1<f64>) -> Result<Option<Array1<f64>>>;

    /// Number of dimensions
    ///
    /// The samplers use this value as the column count of their sample matrices.
    fn ndim(&self) -> usize;
}
21
/// Hamiltonian Monte Carlo (HMC) sampler
///
/// HMC uses gradient information to make efficient proposals in high-dimensional spaces.
/// It simulates Hamiltonian dynamics with momentum variables to explore the posterior.
///
/// Build one with `HamiltonianMonteCarlo::new` and refine it with the
/// `with_*` / `without_adaptation` builder methods; every field is also public.
#[derive(Debug, Clone)]
pub struct HamiltonianMonteCarlo {
    /// Step size for leapfrog integration (the starting value when adaptation is enabled)
    pub stepsize: f64,
    /// Number of leapfrog steps
    pub n_steps: usize,
    /// Mass matrix; `None` makes the sampler use the identity.
    /// The inversion helper in this file reads only the diagonal entries.
    pub mass_matrix: Option<Array2<f64>>,
    /// Random number generator seed; `None` seeds from the thread-local generator
    pub seed: Option<u64>,
    /// Whether to adapt step size
    pub adapt_stepsize: bool,
    /// Target acceptance rate for adaptation
    pub target_acceptance: f64,
    /// Adaptation window size: the step size is only adapted while the
    /// iteration index is below this value
    pub adaptation_window: usize,
}
43
44impl HamiltonianMonteCarlo {
45    /// Create a new HMC sampler
46    pub fn new(stepsize: f64, n_steps: usize) -> Result<Self> {
47        check_positive(stepsize, "stepsize")?;
48        check_positive(n_steps, "n_steps")?;
49
50        Ok(Self {
51            stepsize,
52            n_steps,
53            mass_matrix: None,
54            seed: None,
55            adapt_stepsize: true,
56            target_acceptance: 0.8,
57            adaptation_window: 1000,
58        })
59    }
60
61    /// Set mass matrix
62    pub fn with_mass_matrix(mut self, mass_matrix: Array2<f64>) -> Result<Self> {
63        checkarray_finite(&mass_matrix, "mass_matrix")?;
64        self.mass_matrix = Some(mass_matrix);
65        Ok(self)
66    }
67
68    /// Set random seed
69    pub fn with_seed(mut self, seed: u64) -> Self {
70        self.seed = Some(seed);
71        self
72    }
73
74    /// Disable step size adaptation
75    pub fn without_adaptation(mut self) -> Self {
76        self.adapt_stepsize = false;
77        self
78    }
79
80    /// Sample from the target distribution
81    pub fn sample<D: LogDensity>(
82        &self,
83        target: &D,
84        n_samples_: usize,
85        initial_state: ArrayView1<f64>,
86    ) -> Result<HMCResult> {
87        check_positive(n_samples_, "n_samples_")?;
88        checkarray_finite(&initial_state, "initial_state")?;
89
90        if initial_state.len() != target.ndim() {
91            return Err(StatsError::DimensionMismatch(format!(
92                "initial_state length ({}) must match target dimension ({})",
93                initial_state.len(),
94                target.ndim()
95            )));
96        }
97
98        let mut rng = match self.seed {
99            Some(seed) => StdRng::seed_from_u64(seed),
100            None => StdRng::from_rng(&mut thread_rng()),
101        };
102
103        let ndim = target.ndim();
104        let mut samples = Array2::zeros((n_samples_, ndim));
105        let mut log_probs = Array1::zeros(n_samples_);
106        let mut accepted = Array1::from_elem(n_samples_, false);
107
108        // Initialize mass matrix if not provided
109        let mass_matrix = self
110            .mass_matrix
111            .clone()
112            .unwrap_or_else(|| Array2::eye(ndim));
113        let mass_matrix_inv = self.invert_mass_matrix(&mass_matrix)?;
114
115        let mut current_state = initial_state.to_owned();
116        let mut current_log_prob = target.log_density(&current_state.view())?;
117
118        let mut stepsize = self.stepsize;
119        let mut n_accepted = 0;
120
121        for i in 0..n_samples_ {
122            // Generate momentum
123            let momentum = self.samplemomentum(&mass_matrix, &mut rng)?;
124
125            // Hamiltonian dynamics
126            let (proposed_state, proposedmomentum, proposed_log_prob) = self.leapfrog_integration(
127                &current_state,
128                &momentum,
129                target,
130                &mass_matrix_inv,
131                stepsize,
132            )?;
133
134            // Metropolis acceptance
135            let current_energy =
136                -current_log_prob + 0.5 * self.kinetic_energy(&momentum, &mass_matrix_inv)?;
137            let proposed_energy = -proposed_log_prob
138                + 0.5 * self.kinetic_energy(&proposedmomentum, &mass_matrix_inv)?;
139
140            let accept_prob = (-proposed_energy + current_energy).exp().min(1.0);
141            let accept = rng.random::<f64>() < accept_prob;
142
143            if accept {
144                current_state = proposed_state;
145                current_log_prob = proposed_log_prob;
146                n_accepted += 1;
147                accepted[i] = true;
148            }
149
150            samples.row_mut(i).assign(&current_state);
151            log_probs[i] = current_log_prob;
152
153            // Adapt step size
154            if self.adapt_stepsize && i < self.adaptation_window {
155                stepsize = self.adapt_stepsize_simple(stepsize, accept, self.target_acceptance);
156            }
157        }
158
159        let acceptance_rate = n_accepted as f64 / n_samples_ as f64;
160
161        Ok(HMCResult {
162            samples,
163            log_probabilities: log_probs,
164            accepted,
165            acceptance_rate,
166            final_stepsize: stepsize,
167            n_samples_,
168            ndim,
169        })
170    }
171
172    /// Sample momentum from multivariate normal
173    fn samplemomentum<R: Rng>(
174        &self,
175        mass_matrix: &Array2<f64>,
176        rng: &mut R,
177    ) -> Result<Array1<f64>> {
178        let ndim = mass_matrix.nrows();
179        let mut momentum = Array1::zeros(ndim);
180
181        // Sample from N(0, mass_matrix)
182        for i in 0..ndim {
183            momentum[i] = rng.random::<f64>() * 2.0 - 1.0; // Simplified - should use proper normal sampling
184        }
185
186        // Transform by Cholesky factor of mass _matrix (simplified)
187        let scaledmomentum = mass_matrix.dot(&momentum);
188        Ok(scaledmomentum)
189    }
190
191    /// Compute kinetic energy
192    fn kinetic_energy(&self, momentum: &Array1<f64>, mass_matrix_inv: &Array2<f64>) -> Result<f64> {
193        let kinetic = 0.5 * momentum.dot(&mass_matrix_inv.dot(momentum));
194        Ok(kinetic)
195    }
196
197    /// Leapfrog integration for Hamiltonian dynamics
198    fn leapfrog_integration<D: LogDensity>(
199        &self,
200        initial_position: &Array1<f64>,
201        initialmomentum: &Array1<f64>,
202        target: &D,
203        mass_matrix_inv: &Array2<f64>,
204        stepsize: f64,
205    ) -> Result<(Array1<f64>, Array1<f64>, f64)> {
206        let mut _position = initial_position.clone();
207        let mut momentum = initialmomentum.clone();
208
209        // Half step for momentum
210        if let Some(grad) = target.gradient(&_position.view())? {
211            momentum = &momentum + &(stepsize * 0.5 * &grad);
212        } else {
213            return Err(StatsError::ComputationError(
214                "Gradient required for HMC but not available".to_string(),
215            ));
216        }
217
218        // Full steps
219        for _ in 0..self.n_steps {
220            // Full step for _position
221            _position = &_position + &(stepsize * &mass_matrix_inv.dot(&momentum));
222
223            // Full step for momentum (except last)
224            if let Some(grad) = target.gradient(&_position.view())? {
225                momentum = &momentum + &(stepsize * &grad);
226            }
227        }
228
229        // Final half step for momentum
230        if let Some(grad) = target.gradient(&_position.view())? {
231            momentum = &momentum + &(stepsize * 0.5 * &grad);
232        }
233
234        // Negate momentum for reversibility
235        momentum = -momentum;
236
237        let final_log_prob = target.log_density(&_position.view())?;
238
239        Ok((_position, momentum, final_log_prob))
240    }
241
242    /// Invert mass matrix (simplified)
243    fn invert_mass_matrix(&self, mass_matrix: &Array2<f64>) -> Result<Array2<f64>> {
244        // Simplified inversion - in practice use proper _matrix inversion
245        if mass_matrix.is_square() {
246            // For now, assume diagonal mass _matrix
247            let mut inv = Array2::zeros(mass_matrix.raw_dim());
248            for i in 0..mass_matrix.nrows() {
249                if mass_matrix[[i, i]].abs() < 1e-12 {
250                    return Err(StatsError::ComputationError(
251                        "Mass _matrix is singular".to_string(),
252                    ));
253                }
254                inv[[i, i]] = 1.0 / mass_matrix[[i, i]];
255            }
256            Ok(inv)
257        } else {
258            Err(StatsError::ComputationError(
259                "Mass _matrix must be square".to_string(),
260            ))
261        }
262    }
263
264    /// Simple step size adaptation
265    fn adapt_stepsize_simple(
266        &self,
267        current_stepsize: f64,
268        accepted: bool,
269        target_rate: f64,
270    ) -> f64 {
271        let acceptance_rate = if accepted { 1.0 } else { 0.0 };
272        let adaptation_rate = 0.01;
273
274        if acceptance_rate > target_rate {
275            current_stepsize * (1.0 + adaptation_rate)
276        } else {
277            current_stepsize * (1.0 - adaptation_rate)
278        }
279    }
280}
281
/// Result of HMC sampling
///
/// Returned by `HamiltonianMonteCarlo::sample`.
#[derive(Debug, Clone)]
pub struct HMCResult {
    /// Generated samples (n_samples_ × ndim); one row per iteration, with the
    /// previous state repeated when a proposal was rejected
    pub samples: Array2<f64>,
    /// Log probabilities for each sample
    pub log_probabilities: Array1<f64>,
    /// Acceptance indicators for each sample (`true` when that iteration's
    /// proposal was accepted)
    pub accepted: Array1<bool>,
    /// Overall acceptance rate (accepted proposals divided by `n_samples_`)
    pub acceptance_rate: f64,
    /// Final adapted step size (the step size in effect when sampling ended)
    pub final_stepsize: f64,
    /// Number of samples
    pub n_samples_: usize,
    /// Number of dimensions
    pub ndim: usize,
}
300
/// No-U-Turn Sampler (NUTS)
///
/// NUTS is an extension of HMC that automatically tunes the number of leapfrog steps
/// by stopping when the trajectory starts to double back on itself.
///
/// NOTE(review): the implementation in this file describes its own tree
/// building as simplified; it is not a complete NUTS.
#[derive(Debug, Clone)]
pub struct NoUTurnSampler {
    /// Initial step size
    pub initial_stepsize: f64,
    /// Maximum tree depth
    pub max_tree_depth: usize,
    /// Target acceptance rate for step size adaptation
    pub target_acceptance: f64,
    /// Step size adaptation parameter (`gamma` of the dual-averaging scheme)
    pub gamma: f64,
    /// Step size adaptation parameter (`t0`, added to the iteration count in
    /// the running-average weight)
    pub t0: f64,
    /// Step size adaptation parameter (`kappa`, the decay exponent of the
    /// averaging weight)
    pub kappa: f64,
    /// Random seed; `None` seeds from the thread-local generator
    pub seed: Option<u64>,
}
322
/// `Default` delegates to [`NoUTurnSampler::new`], so both produce the same
/// configuration.
impl Default for NoUTurnSampler {
    fn default() -> Self {
        Self::new()
    }
}
328
329impl NoUTurnSampler {
330    /// Create a new NUTS sampler
331    pub fn new() -> Self {
332        Self {
333            initial_stepsize: 1.0,
334            max_tree_depth: 10,
335            target_acceptance: 0.8,
336            gamma: 0.05,
337            t0: 10.0,
338            kappa: 0.75,
339            seed: None,
340        }
341    }
342
343    /// Set initial step size
344    pub fn with_stepsize(mut self, stepsize: f64) -> Self {
345        self.initial_stepsize = stepsize;
346        self
347    }
348
349    /// Set maximum tree depth
350    pub fn with_max_depth(mut self, depth: usize) -> Self {
351        self.max_tree_depth = depth;
352        self
353    }
354
355    /// Sample using NUTS algorithm
356    pub fn sample<D: LogDensity>(
357        &self,
358        target: &D,
359        n_samples_: usize,
360        initial_state: ArrayView1<f64>,
361    ) -> Result<NUTSResult> {
362        check_positive(n_samples_, "n_samples_")?;
363        checkarray_finite(&initial_state, "initial_state")?;
364
365        let mut rng = match self.seed {
366            Some(seed) => StdRng::seed_from_u64(seed),
367            None => StdRng::from_rng(&mut thread_rng()),
368        };
369
370        let ndim = target.ndim();
371        let mut samples = Array2::zeros((n_samples_, ndim));
372        let mut log_probs = Array1::zeros(n_samples_);
373        let mut treesizes = Array1::zeros(n_samples_);
374
375        let mut current_state = initial_state.to_owned();
376        let mut current_log_prob;
377
378        let mut stepsize = self.initial_stepsize;
379        let mut stepsize_bar = self.initial_stepsize;
380        let mut h_bar = 0.0;
381
382        for i in 0..n_samples_ {
383            // Sample momentum
384            let momentum = self.samplemomentum(ndim, &mut rng);
385
386            // Build tree and sample
387            let (new_state, new_log_prob, treesize) =
388                self.build_tree(&current_state, &momentum, target, stepsize, &mut rng)?;
389
390            current_state = new_state;
391            current_log_prob = new_log_prob;
392
393            samples.row_mut(i).assign(&current_state);
394            log_probs[i] = current_log_prob;
395            treesizes[i] = treesize as f64;
396
397            // Adapt step size using dual averaging
398            if i < n_samples_ / 2 {
399                let acceptance_prob = 1.0; // Simplified - should track actual acceptance
400                h_bar = (1.0 - 1.0 / (i as f64 + self.t0)) * h_bar
401                    + (self.target_acceptance - acceptance_prob) / (i as f64 + self.t0);
402
403                stepsize = self.initial_stepsize * (-h_bar).exp();
404
405                let eta = (i as f64 + 1.0).powf(-self.kappa);
406                stepsize_bar = (-eta * h_bar).exp() * (i as f64 + 1.0).powf(-self.kappa)
407                    + (1.0 - (i as f64 + 1.0).powf(-self.kappa)) * stepsize_bar;
408            } else {
409                stepsize = stepsize_bar;
410            }
411        }
412
413        Ok(NUTSResult {
414            samples,
415            log_probabilities: log_probs,
416            treesizes,
417            final_stepsize: stepsize,
418            n_samples_,
419            ndim,
420        })
421    }
422
423    /// Sample momentum from standard normal
424    fn samplemomentum<R: Rng>(&self, ndim: usize, rng: &mut R) -> Array1<f64> {
425        let mut momentum = Array1::zeros(ndim);
426        for i in 0..ndim {
427            // Simplified normal sampling using Box-Muller
428            let u1: f64 = rng.random();
429            let u2: f64 = rng.random();
430            momentum[i] = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
431        }
432        momentum
433    }
434
435    /// Build binary tree for NUTS
436    fn build_tree<D: LogDensity, R: Rng>(
437        &self,
438        position: &Array1<f64>,
439        momentum: &Array1<f64>,
440        target: &D,
441        stepsize: f64,
442        rng: &mut R,
443    ) -> Result<(Array1<f64>, f64, usize)> {
444        // Simplified tree building - this is a basic implementation
445        // A full NUTS implementation would be much more complex
446
447        let mut current_pos = position.clone();
448        let mut current_mom = momentum.clone();
449        let _current_log_prob = target.log_density(&current_pos.view())?;
450
451        // Take a few leapfrog steps (simplified)
452        let n_steps = 2_usize.pow(rng.random_range(1..self.max_tree_depth.min(4) + 1) as u32);
453
454        for _ in 0..n_steps {
455            // Simplified leapfrog step
456            if let Some(grad) = target.gradient(&current_pos.view())? {
457                current_mom = &current_mom + &(stepsize * 0.5 * &grad);
458                current_pos = &current_pos + &(stepsize * &current_mom);
459                if let Some(grad) = target.gradient(&current_pos.view())? {
460                    current_mom = &current_mom + &(stepsize * 0.5 * &grad);
461                }
462            }
463        }
464
465        let new_log_prob = target.log_density(&current_pos.view())?;
466
467        Ok((current_pos, new_log_prob, n_steps))
468    }
469}
470
/// Result of NUTS sampling
///
/// Returned by `NoUTurnSampler::sample`.
#[derive(Debug, Clone)]
pub struct NUTSResult {
    /// Generated samples (n_samples_ × ndim), one row per iteration
    pub samples: Array2<f64>,
    /// Log probabilities for each sample
    pub log_probabilities: Array1<f64>,
    /// Tree sizes for each iteration (the number of leapfrog steps taken,
    /// stored as `f64`)
    pub treesizes: Array1<f64>,
    /// Final adapted step size
    pub final_stepsize: f64,
    /// Number of samples
    pub n_samples_: usize,
    /// Number of dimensions
    pub ndim: usize,
}
487
/// Adaptive Metropolis sampler
///
/// Adapts the proposal covariance based on the sample history to improve efficiency.
#[derive(Debug, Clone)]
pub struct AdaptiveMetropolis {
    /// Initial proposal covariance; `None` makes the sampler start from the identity
    pub initial_covariance: Option<Array2<f64>>,
    /// Adaptation start after this many samples (iterations with an index
    /// below this value do not contribute to the running statistics)
    pub adaptation_start: usize,
    /// Scaling factor for covariance; the sampler additionally divides it by
    /// the number of dimensions when forming the proposal covariance
    pub scale_factor: f64,
    /// Small constant to prevent singularity (added to the diagonal of every
    /// adapted covariance)
    pub epsilon: f64,
    /// Random seed; `None` seeds from the thread-local generator
    pub seed: Option<u64>,
}
504
/// `Default` delegates to [`AdaptiveMetropolis::new`], so both produce the
/// same configuration.
impl Default for AdaptiveMetropolis {
    fn default() -> Self {
        Self::new()
    }
}
510
511impl AdaptiveMetropolis {
512    /// Create a new adaptive Metropolis sampler
513    pub fn new() -> Self {
514        Self {
515            initial_covariance: None,
516            adaptation_start: 100,
517            scale_factor: 2.38 * 2.38, // Optimal scaling for multivariate normal
518            epsilon: 1e-6,
519            seed: None,
520        }
521    }
522
523    /// Set initial covariance matrix
524    pub fn with_covariance(mut self, cov: Array2<f64>) -> Self {
525        self.initial_covariance = Some(cov);
526        self
527    }
528
529    /// Sample using adaptive Metropolis
530    pub fn sample<D: LogDensity>(
531        &self,
532        target: &D,
533        n_samples_: usize,
534        initial_state: ArrayView1<f64>,
535    ) -> Result<AdaptiveMetropolisResult> {
536        check_positive(n_samples_, "n_samples_")?;
537        checkarray_finite(&initial_state, "initial_state")?;
538
539        let mut rng = match self.seed {
540            Some(seed) => StdRng::seed_from_u64(seed),
541            None => StdRng::from_rng(&mut thread_rng()),
542        };
543
544        let ndim = target.ndim();
545        let mut samples = Array2::zeros((n_samples_, ndim));
546        let mut log_probs = Array1::zeros(n_samples_);
547        let mut accepted = Array1::from_elem(n_samples_, false);
548
549        let mut current_state = initial_state.to_owned();
550        let mut current_log_prob = target.log_density(&current_state.view())?;
551
552        // Initialize covariance
553        let mut covariance = self
554            .initial_covariance
555            .clone()
556            .unwrap_or_else(|| Array2::eye(ndim));
557
558        let mut sample_mean = Array1::zeros(ndim);
559        let mut sample_cov = Array2::zeros((ndim, ndim));
560        let mut n_adapted = 0;
561        let mut n_accepted = 0;
562
563        for i in 0..n_samples_ {
564            // Generate proposal
565            let proposal = self.generate_proposal(&current_state, &covariance, &mut rng)?;
566            let proposal_log_prob = target.log_density(&proposal.view())?;
567
568            // Metropolis acceptance
569            let log_accept_prob = proposal_log_prob - current_log_prob;
570            let accept = log_accept_prob.exp() > rng.random::<f64>();
571
572            if accept {
573                current_state = proposal;
574                current_log_prob = proposal_log_prob;
575                n_accepted += 1;
576                accepted[i] = true;
577            }
578
579            samples.row_mut(i).assign(&current_state);
580            log_probs[i] = current_log_prob;
581
582            // Update adaptation statistics
583            if i >= self.adaptation_start {
584                n_adapted += 1;
585                let delta = &current_state - &sample_mean;
586                sample_mean = &sample_mean + &delta / (n_adapted as f64);
587
588                // Update sample covariance (Welford's algorithm)
589                let delta2 = &current_state - &sample_mean;
590                for j in 0..ndim {
591                    for k in 0..ndim {
592                        sample_cov[[j, k]] += delta[j] * delta2[k];
593                    }
594                }
595
596                // Update proposal covariance
597                if n_adapted > 1 {
598                    covariance =
599                        &sample_cov / (n_adapted - 1) as f64 * self.scale_factor / ndim as f64;
600
601                    // Add small diagonal term for numerical stability
602                    for j in 0..ndim {
603                        covariance[[j, j]] += self.epsilon;
604                    }
605                }
606            }
607        }
608
609        let acceptance_rate = n_accepted as f64 / n_samples_ as f64;
610
611        Ok(AdaptiveMetropolisResult {
612            samples,
613            log_probabilities: log_probs,
614            accepted,
615            acceptance_rate,
616            final_covariance: covariance,
617            n_samples_,
618            ndim,
619        })
620    }
621
622    /// Generate proposal using multivariate normal
623    fn generate_proposal<R: Rng>(
624        &self,
625        current: &Array1<f64>,
626        covariance: &Array2<f64>,
627        rng: &mut R,
628    ) -> Result<Array1<f64>> {
629        let ndim = current.len();
630
631        // Sample from N(0, I)
632        let mut z = Array1::zeros(ndim);
633        for i in 0..ndim {
634            let u1: f64 = rng.random();
635            let u2: f64 = rng.random();
636            z[i] = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
637        }
638
639        // Transform to N(current, covariance) using Cholesky decomposition
640        // Simplified: assume diagonal covariance for now
641        let mut proposal = current.clone();
642        for i in 0..ndim {
643            proposal[i] += z[i] * covariance[[i, i]].sqrt();
644        }
645
646        Ok(proposal)
647    }
648}
649
/// Result of adaptive Metropolis sampling
///
/// Returned by `AdaptiveMetropolis::sample`.
#[derive(Debug, Clone)]
pub struct AdaptiveMetropolisResult {
    /// Generated samples (n_samples_ × ndim); one row per iteration, with the
    /// previous state repeated when a proposal was rejected
    pub samples: Array2<f64>,
    /// Log probabilities for each sample
    pub log_probabilities: Array1<f64>,
    /// Acceptance indicators for each sample
    pub accepted: Array1<bool>,
    /// Overall acceptance rate (accepted proposals divided by `n_samples_`)
    pub acceptance_rate: f64,
    /// Final adapted covariance matrix (the proposal covariance in effect
    /// when sampling ended)
    pub final_covariance: Array2<f64>,
    /// Number of samples
    pub n_samples_: usize,
    /// Number of dimensions
    pub ndim: usize,
}
668
/// Parallel tempering (replica exchange) MCMC
///
/// Runs multiple chains at different temperatures to improve mixing and exploration
/// of multimodal distributions.
#[derive(Debug, Clone)]
pub struct ParallelTempering {
    /// Temperature ladder; one chain is run per entry
    pub temperatures: Array1<f64>,
    /// Scale of the random-walk proposal used by every chain (each coordinate
    /// is perturbed by `stepsize` times a standard-normal draw)
    pub stepsize: f64,
    /// Number of steps between swap attempts (swaps are attempted on
    /// iterations whose index is a multiple of this value)
    pub swap_interval: usize,
    /// Random seed; `None` seeds from the thread-local generator
    pub seed: Option<u64>,
}
684
685impl ParallelTempering {
686    /// Create a new parallel tempering sampler
687    pub fn new(temperatures: Array1<f64>, stepsize: f64) -> Result<Self> {
688        checkarray_finite(&temperatures, "temperatures")?;
689        check_positive(stepsize, "stepsize")?;
690
691        for &temp in temperatures.iter() {
692            if temp <= 0.0 {
693                return Err(StatsError::InvalidArgument(
694                    "All _temperatures must be positive".to_string(),
695                ));
696            }
697        }
698
699        Ok(Self {
700            temperatures,
701            stepsize,
702            swap_interval: 10,
703            seed: None,
704        })
705    }
706
707    /// Sample using parallel tempering
708    pub fn sample<D: LogDensity + Send + Sync>(
709        &self,
710        target: &D,
711        n_samples_: usize,
712        initial_states: ArrayView2<f64>,
713    ) -> Result<ParallelTemperingResult> {
714        check_positive(n_samples_, "n_samples_")?;
715        checkarray_finite(&initial_states, "initial_states")?;
716
717        let n_chains = self.temperatures.len();
718        let ndim = target.ndim();
719
720        if initial_states.nrows() != n_chains || initial_states.ncols() != ndim {
721            return Err(StatsError::DimensionMismatch(format!(
722                "initial_states shape ({}, {}) must match (n_chains={}, ndim={})",
723                initial_states.nrows(),
724                initial_states.ncols(),
725                n_chains,
726                ndim
727            )));
728        }
729
730        let mut rng = match self.seed {
731            Some(seed) => StdRng::seed_from_u64(seed),
732            None => StdRng::from_rng(&mut thread_rng()),
733        };
734
735        // Initialize chains
736        let mut chain_samples_ = vec![Array2::zeros((n_samples_, ndim)); n_chains];
737        let mut chain_log_probs = vec![Array1::zeros(n_samples_); n_chains];
738        let mut current_states: Vec<Array1<f64>> = initial_states
739            .rows()
740            .into_iter()
741            .map(|row| row.to_owned())
742            .collect();
743
744        let mut current_log_probs = vec![0.0; n_chains];
745        for (i, state) in current_states.iter().enumerate() {
746            current_log_probs[i] = target.log_density(&state.view())?;
747        }
748
749        let mut n_swaps_attempted = 0;
750        let mut n_swaps_accepted = 0;
751
752        for sample_idx in 0..n_samples_ {
753            // Update each chain with Metropolis step
754            for chain_idx in 0..n_chains {
755                let temp = self.temperatures[chain_idx];
756                let (new_state, new_log_prob) = self.metropolis_step(
757                    &current_states[chain_idx],
758                    current_log_probs[chain_idx],
759                    target,
760                    temp,
761                    &mut rng,
762                )?;
763
764                current_states[chain_idx] = new_state;
765                current_log_probs[chain_idx] = new_log_prob;
766
767                chain_samples_[chain_idx]
768                    .row_mut(sample_idx)
769                    .assign(&current_states[chain_idx]);
770                chain_log_probs[chain_idx][sample_idx] = current_log_probs[chain_idx];
771            }
772
773            // Attempt swaps between adjacent temperatures
774            if sample_idx % self.swap_interval == 0 && n_chains > 1 {
775                for i in 0..n_chains - 1 {
776                    n_swaps_attempted += 1;
777
778                    let temp_i = self.temperatures[i];
779                    let temp_j = self.temperatures[i + 1];
780                    let log_prob_i = current_log_probs[i];
781                    let log_prob_j = current_log_probs[i + 1];
782
783                    // Compute swap probability
784                    let beta_i = 1.0 / temp_i;
785                    let beta_j = 1.0 / temp_j;
786                    let log_swap_prob = (beta_i - beta_j) * (log_prob_j - log_prob_i);
787
788                    if log_swap_prob.exp() > rng.random::<f64>() {
789                        // Accept swap
790                        current_states.swap(i, i + 1);
791                        current_log_probs.swap(i, i + 1);
792                        n_swaps_accepted += 1;
793                    }
794                }
795            }
796        }
797
798        let swap_acceptance_rate = if n_swaps_attempted > 0 {
799            n_swaps_accepted as f64 / n_swaps_attempted as f64
800        } else {
801            0.0
802        };
803
804        Ok(ParallelTemperingResult {
805            chain_samples_,
806            chain_log_probabilities: chain_log_probs,
807            temperatures: self.temperatures.clone(),
808            swap_acceptance_rate,
809            n_samples_,
810            n_chains,
811            ndim,
812        })
813    }
814
815    /// Single Metropolis step for a tempered chain
816    fn metropolis_step<D: LogDensity, R: Rng>(
817        &self,
818        current: &Array1<f64>,
819        current_log_prob: f64,
820        target: &D,
821        temperature: f64,
822        rng: &mut R,
823    ) -> Result<(Array1<f64>, f64)> {
824        // Simple random walk proposal
825        let ndim = current.len();
826        let mut proposal = current.clone();
827
828        for i in 0..ndim {
829            let u1: f64 = rng.random();
830            let u2: f64 = rng.random();
831            let normal_sample = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
832            proposal[i] += self.stepsize * normal_sample;
833        }
834
835        let proposal_log_prob = target.log_density(&proposal.view())?;
836
837        // Tempered acceptance probability
838        let log_accept_prob = (proposal_log_prob - current_log_prob) / temperature;
839
840        if log_accept_prob.exp() > rng.random::<f64>() {
841            Ok((proposal, proposal_log_prob))
842        } else {
843            Ok((current.clone(), current_log_prob))
844        }
845    }
846}
847
/// Result of parallel tempering sampling
///
/// Returned by `ParallelTempering::sample`.
#[derive(Debug, Clone)]
pub struct ParallelTemperingResult {
    /// Samples from each chain (one per temperature); entry `c` is an
    /// `n_samples_ × ndim` matrix for the chain at `temperatures[c]`
    pub chain_samples_: Vec<Array2<f64>>,
    /// Log probabilities for each chain (untempered log densities of the
    /// recorded states)
    pub chain_log_probabilities: Vec<Array1<f64>>,
    /// Temperature ladder used
    pub temperatures: Array1<f64>,
    /// Rate of accepted temperature swaps (0.0 when no swap was attempted)
    pub swap_acceptance_rate: f64,
    /// Number of samples per chain
    pub n_samples_: usize,
    /// Number of chains (temperatures)
    pub n_chains: usize,
    /// Number of dimensions
    pub ndim: usize,
}
866
867impl ParallelTemperingResult {
868    /// Get samples from the cold chain (temperature = 1.0)
869    pub fn cold_chain_samples_(&self) -> Result<&Array2<f64>> {
870        // Find chain with temperature closest to 1.0
871        let mut min_diff = f64::INFINITY;
872        let mut cold_idx = 0;
873
874        for (i, &temp) in self.temperatures.iter().enumerate() {
875            let diff = (temp - 1.0).abs();
876            if diff < min_diff {
877                min_diff = diff;
878                cold_idx = i;
879            }
880        }
881
882        Ok(&self.chain_samples_[cold_idx])
883    }
884}
885
/// Example multivariate normal target distribution for testing
///
/// Stores the distribution in precision form, which is what the
/// `LogDensity` implementation evaluates.
#[derive(Debug, Clone)]
pub struct MultivariateNormal {
    /// Mean vector; its length is the dimension reported by `ndim`
    pub mean: Array1<f64>,
    /// Precision matrix used in the quadratic form of the log density
    pub precision: Array2<f64>,
    /// Log-determinant of the precision matrix, used in the normalising constant
    pub log_det_precision: f64,
}
893
894impl MultivariateNormal {
895    pub fn new(mean: Array1<f64>, covariance: Array2<f64>) -> Result<Self> {
896        checkarray_finite(&mean, "mean")?;
897        checkarray_finite(&covariance, "covariance")?;
898
899        // Simplified precision computation (should use proper matrix inversion)
900        let precision = Array2::eye(mean.len()); // Placeholder
901        let log_det_precision = 0.0; // Placeholder
902
903        Ok(Self {
904            mean,
905            precision,
906            log_det_precision,
907        })
908    }
909}
910
911impl LogDensity for MultivariateNormal {
912    fn log_density(&self, theta: &ArrayView1<f64>) -> Result<f64> {
913        let diff = theta - &self.mean;
914        let quad_form = diff.dot(&self.precision.dot(&diff));
915        let log_prob = -0.5 * quad_form + 0.5 * self.log_det_precision
916            - 0.5 * self.mean.len() as f64 * (2.0 * std::f64::consts::PI).ln();
917        Ok(log_prob)
918    }
919
920    fn gradient(&self, theta: &ArrayView1<f64>) -> Result<Option<Array1<f64>>> {
921        let diff = theta - &self.mean;
922        let grad = -self.precision.dot(&diff);
923        Ok(Some(grad))
924    }
925
926    fn ndim(&self) -> usize {
927        self.mean.len()
928    }
929}