Skip to main content

scirs2_stats/mcmc/
metropolis.rs

1//! Metropolis-Hastings algorithm for MCMC sampling
2
3use crate::error::{StatsError, StatsResult as Result};
4use scirs2_core::ndarray::{Array1, Array2};
5use scirs2_core::random::{Distribution, Uniform};
6use scirs2_core::validation::*;
7use scirs2_linalg::{det, inv};
8use std::fmt::Debug;
9
10/// Target distribution trait for MCMC sampling
11pub trait TargetDistribution: Send + Sync {
12    /// Compute the log probability density at a given point
13    fn log_density(&self, x: &Array1<f64>) -> f64;
14
15    /// Get the dimensionality of the distribution
16    fn dim(&self) -> usize;
17}
18
19/// Proposal distribution trait for Metropolis-Hastings
20pub trait ProposalDistribution: Send + Sync {
21    /// Sample a new proposal given the current state
22    fn sample<R: scirs2_core::random::Rng + ?Sized>(
23        &self,
24        current: &Array1<f64>,
25        rng: &mut R,
26    ) -> Array1<f64>;
27
28    /// Compute the log density ratio q(x|y) / q(y|x) for asymmetric proposals
29    fn log_ratio(from: &Array1<f64>, to: &Array1<f64>) -> f64 {
30        0.0 // Default _to symmetric proposal
31    }
32}
33
34/// Random walk proposal with normal distribution
35#[derive(Debug, Clone)]
36pub struct RandomWalkProposal {
37    /// Step size (standard deviation)
38    pub stepsize: f64,
39}
40
41impl RandomWalkProposal {
42    /// Create a new random walk proposal
43    pub fn new(stepsize: f64) -> Result<Self> {
44        check_positive(stepsize, "stepsize")?;
45        Ok(Self { stepsize })
46    }
47}
48
49impl ProposalDistribution for RandomWalkProposal {
50    fn sample<R: scirs2_core::random::Rng + ?Sized>(
51        &self,
52        current: &Array1<f64>,
53        rng: &mut R,
54    ) -> Array1<f64> {
55        use scirs2_core::random::Normal;
56        let normal = Normal::new(0.0, self.stepsize).expect("Operation failed");
57        current + Array1::from_shape_fn(current.len(), |_| normal.sample(rng))
58    }
59}
60
61/// Metropolis-Hastings sampler
62pub struct MetropolisHastings<T: TargetDistribution, P: ProposalDistribution> {
63    /// Target distribution to sample from
64    pub target: T,
65    /// Proposal distribution
66    pub proposal: P,
67    /// Current state
68    pub current: Array1<f64>,
69    /// Current log density
70    pub current_log_density: f64,
71    /// Number of accepted proposals
72    pub n_accepted: usize,
73    /// Total number of proposals
74    pub n_proposed: usize,
75}
76
77impl<T: TargetDistribution, P: ProposalDistribution> MetropolisHastings<T, P> {
78    /// Create a new Metropolis-Hastings sampler
79    pub fn new(target: T, proposal: P, initial: Array1<f64>) -> Result<Self> {
80        checkarray_finite(&initial, "initial")?;
81        if initial.len() != target.dim() {
82            return Err(StatsError::DimensionMismatch(format!(
83                "initial dimension ({}) must match _target dimension ({})",
84                initial.len(),
85                target.dim()
86            )));
87        }
88
89        let current_log_density = target.log_density(&initial);
90
91        Ok(Self {
92            target,
93            proposal,
94            current: initial,
95            current_log_density,
96            n_accepted: 0,
97            n_proposed: 0,
98        })
99    }
100
101    /// Perform one step of the Metropolis-Hastings algorithm
102    pub fn step<R: scirs2_core::random::Rng + ?Sized>(&mut self, rng: &mut R) -> Array1<f64> {
103        // Propose new state
104        let proposed = self.proposal.sample(&self.current, rng);
105        let proposed_log_density = self.target.log_density(&proposed);
106
107        // Compute acceptance ratio
108        let log_ratio = proposed_log_density - self.current_log_density
109            + P::log_ratio(&self.current, &proposed);
110
111        // Accept or reject
112        self.n_proposed += 1;
113        let u: f64 = Uniform::new(0.0, 1.0)
114            .expect("Operation failed")
115            .sample(rng);
116        if u.ln() < log_ratio {
117            self.current = proposed;
118            self.current_log_density = proposed_log_density;
119            self.n_accepted += 1;
120        }
121
122        self.current.clone()
123    }
124
125    /// Sample multiple states from the distribution
126    pub fn sample<R: scirs2_core::random::Rng + ?Sized>(
127        &mut self,
128        nsamples_: usize,
129        rng: &mut R,
130    ) -> Array2<f64> {
131        let dim = self.current.len();
132        let mut samples = Array2::zeros((nsamples_, dim));
133
134        for i in 0..nsamples_ {
135            let sample = self.step(rng);
136            samples.row_mut(i).assign(&sample);
137        }
138
139        samples
140    }
141
142    /// Sample with thinning to reduce autocorrelation
143    pub fn sample_thinned<R: scirs2_core::random::Rng + ?Sized>(
144        &mut self,
145        n_samples_: usize,
146        thin: usize,
147        rng: &mut R,
148    ) -> Result<Array2<f64>> {
149        check_positive(thin, "thin")?;
150
151        let dim = self.current.len();
152        let mut samples = Array2::zeros((n_samples_, dim));
153
154        for i in 0..n_samples_ {
155            // Take thin steps but only keep the last one
156            for _ in 0..thin {
157                self.step(rng);
158            }
159            samples.row_mut(i).assign(&self.current);
160        }
161
162        Ok(samples)
163    }
164
165    /// Get the acceptance rate
166    pub fn acceptance_rate(&self) -> f64 {
167        if self.n_proposed == 0 {
168            0.0
169        } else {
170            self.n_accepted as f64 / self.n_proposed as f64
171        }
172    }
173
174    /// Reset counters
175    pub fn reset_counters(&mut self) {
176        self.n_accepted = 0;
177        self.n_proposed = 0;
178    }
179}
180
181/// Adaptive Metropolis-Hastings that adjusts proposal step size
182pub struct AdaptiveMetropolisHastings<T: TargetDistribution> {
183    /// Base sampler
184    pub sampler: MetropolisHastings<T, RandomWalkProposal>,
185    /// Target acceptance rate
186    pub target_rate: f64,
187    /// Adaptation rate
188    pub adaptation_rate: f64,
189    /// Minimum step size
190    pub min_stepsize: f64,
191    /// Maximum step size
192    pub max_stepsize: f64,
193}
194
195impl<T: TargetDistribution> AdaptiveMetropolisHastings<T> {
196    /// Create a new adaptive Metropolis-Hastings sampler
197    pub fn new(
198        target: T,
199        initial: Array1<f64>,
200        initial_stepsize: f64,
201        target_rate: f64,
202    ) -> Result<Self> {
203        check_probability(target_rate, "target_rate")?;
204        check_positive(initial_stepsize, "initial_stepsize")?;
205
206        let proposal = RandomWalkProposal::new(initial_stepsize)?;
207        let sampler = MetropolisHastings::new(target, proposal, initial)?;
208
209        Ok(Self {
210            sampler,
211            target_rate,
212            adaptation_rate: 0.05,
213            min_stepsize: 1e-6,
214            max_stepsize: 10.0,
215        })
216    }
217
218    /// Perform one adaptive step
219    pub fn step<R: scirs2_core::random::Rng + ?Sized>(&mut self, rng: &mut R) -> Array1<f64> {
220        let sample = self.sampler.step(rng);
221
222        // Adapt step size based on acceptance rate
223        if self.sampler.n_proposed.is_multiple_of(100) && self.sampler.n_proposed > 0 {
224            let current_rate = self.sampler.acceptance_rate();
225            let adjustment = 1.0 + self.adaptation_rate * (current_rate - self.target_rate);
226
227            let new_stepsize = (self.sampler.proposal.stepsize * adjustment)
228                .max(self.min_stepsize)
229                .min(self.max_stepsize);
230
231            self.sampler.proposal.stepsize = new_stepsize;
232        }
233
234        sample
235    }
236
237    /// Run adaptation phase
238    pub fn adapt<R: scirs2_core::random::Rng + ?Sized>(
239        &mut self,
240        nsteps: usize,
241        rng: &mut R,
242    ) -> Result<()> {
243        check_positive(nsteps, "n_steps")?;
244
245        for _ in 0..nsteps {
246            self.step(rng);
247        }
248
249        // Reset counters after adaptation
250        self.sampler.reset_counters();
251        Ok(())
252    }
253}
254
255// Example implementations for common distributions
256
257/// Multivariate normal target distribution
258#[derive(Debug, Clone)]
259pub struct MultivariateNormalTarget {
260    /// Mean vector
261    pub mean: Array1<f64>,
262    /// Precision matrix (inverse covariance)
263    pub precision: Array2<f64>,
264    /// Log normalizing constant
265    pub log_norm_const: f64,
266}
267
268impl MultivariateNormalTarget {
269    /// Create a new multivariate normal target
270    pub fn new(mean: Array1<f64>, covariance: Array2<f64>) -> Result<Self> {
271        checkarray_finite(&mean, "mean")?;
272        checkarray_finite(&covariance, "covariance")?;
273        if covariance.nrows() != mean.len() || covariance.ncols() != mean.len() {
274            return Err(StatsError::DimensionMismatch(format!(
275                "covariance shape ({}, {}) must be ({}, {})",
276                covariance.nrows(),
277                covariance.ncols(),
278                mean.len(),
279                mean.len()
280            )));
281        }
282
283        // Compute precision matrix (inverse of covariance)
284        let precision = inv(&covariance.view(), None).map_err(|e| {
285            StatsError::ComputationError(format!("Failed to invert covariance matrix: {}", e))
286        })?;
287
288        // Compute determinant
289        let det_value = det(&covariance.view(), None).map_err(|e| {
290            StatsError::ComputationError(format!("Failed to compute determinant: {}", e))
291        })?;
292
293        if det_value <= 0.0 {
294            return Err(StatsError::InvalidArgument(
295                "Covariance matrix must be positive definite".to_string(),
296            ));
297        }
298
299        let d = mean.len() as f64;
300        let log_norm_const = -0.5 * (d * (2.0 * std::f64::consts::PI).ln() + det_value.ln());
301
302        Ok(Self {
303            mean,
304            precision,
305            log_norm_const,
306        })
307    }
308}
309
310impl TargetDistribution for MultivariateNormalTarget {
311    fn log_density(&self, x: &Array1<f64>) -> f64 {
312        let diff = x - &self.mean;
313        let quad_form = diff.dot(&self.precision.dot(&diff));
314        self.log_norm_const - 0.5 * quad_form
315    }
316
317    fn dim(&self) -> usize {
318        self.mean.len()
319    }
320}
321
322/// Custom target distribution from a log density function
323pub struct CustomTarget<F> {
324    /// Log density function
325    pub log_density_fn: F,
326    /// Dimensionality
327    pub dim: usize,
328}
329
330impl<F> CustomTarget<F> {
331    /// Create a new custom target distribution
332    pub fn new(dim: usize, log_densityfn: F) -> Result<Self> {
333        check_positive(dim, "dim")?;
334        Ok(Self {
335            log_density_fn: log_densityfn,
336            dim,
337        })
338    }
339}
340
341impl<F> TargetDistribution for CustomTarget<F>
342where
343    F: Fn(&Array1<f64>) -> f64 + Send + Sync,
344{
345    fn log_density(&self, x: &Array1<f64>) -> f64 {
346        (self.log_density_fn)(x)
347    }
348
349    fn dim(&self) -> usize {
350        self.dim
351    }
352}