1use crate::error::{StatsError, StatsResult as Result};
4use scirs2_core::ndarray::{Array1, Array2};
5use scirs2_core::random::{Distribution, Uniform};
6use scirs2_core::validation::*;
7use scirs2_linalg::{det, inv};
8use std::fmt::Debug;
9
10pub trait TargetDistribution: Send + Sync {
12 fn log_density(&self, x: &Array1<f64>) -> f64;
14
15 fn dim(&self) -> usize;
17}
18
19pub trait ProposalDistribution: Send + Sync {
21 fn sample<R: scirs2_core::random::Rng + ?Sized>(
23 &self,
24 current: &Array1<f64>,
25 rng: &mut R,
26 ) -> Array1<f64>;
27
28 fn log_ratio(from: &Array1<f64>, to: &Array1<f64>) -> f64 {
30 0.0 }
32}
33
/// Symmetric Gaussian random-walk proposal.
///
/// Each coordinate of the current state is perturbed independently by noise
/// from `Normal::new(0.0, stepsize)`.
#[derive(Debug, Clone)]
pub struct RandomWalkProposal {
    /// Scale (standard deviation) of the per-coordinate Gaussian perturbation.
    pub stepsize: f64,
}
40
41impl RandomWalkProposal {
42 pub fn new(stepsize: f64) -> Result<Self> {
44 check_positive(stepsize, "stepsize")?;
45 Ok(Self { stepsize })
46 }
47}
48
49impl ProposalDistribution for RandomWalkProposal {
50 fn sample<R: scirs2_core::random::Rng + ?Sized>(
51 &self,
52 current: &Array1<f64>,
53 rng: &mut R,
54 ) -> Array1<f64> {
55 use scirs2_core::random::Normal;
56 let normal = Normal::new(0.0, self.stepsize).expect("Operation failed");
57 current + Array1::from_shape_fn(current.len(), |_| normal.sample(rng))
58 }
59}
60
61pub struct MetropolisHastings<T: TargetDistribution, P: ProposalDistribution> {
63 pub target: T,
65 pub proposal: P,
67 pub current: Array1<f64>,
69 pub current_log_density: f64,
71 pub n_accepted: usize,
73 pub n_proposed: usize,
75}
76
77impl<T: TargetDistribution, P: ProposalDistribution> MetropolisHastings<T, P> {
78 pub fn new(target: T, proposal: P, initial: Array1<f64>) -> Result<Self> {
80 checkarray_finite(&initial, "initial")?;
81 if initial.len() != target.dim() {
82 return Err(StatsError::DimensionMismatch(format!(
83 "initial dimension ({}) must match _target dimension ({})",
84 initial.len(),
85 target.dim()
86 )));
87 }
88
89 let current_log_density = target.log_density(&initial);
90
91 Ok(Self {
92 target,
93 proposal,
94 current: initial,
95 current_log_density,
96 n_accepted: 0,
97 n_proposed: 0,
98 })
99 }
100
101 pub fn step<R: scirs2_core::random::Rng + ?Sized>(&mut self, rng: &mut R) -> Array1<f64> {
103 let proposed = self.proposal.sample(&self.current, rng);
105 let proposed_log_density = self.target.log_density(&proposed);
106
107 let log_ratio = proposed_log_density - self.current_log_density
109 + P::log_ratio(&self.current, &proposed);
110
111 self.n_proposed += 1;
113 let u: f64 = Uniform::new(0.0, 1.0)
114 .expect("Operation failed")
115 .sample(rng);
116 if u.ln() < log_ratio {
117 self.current = proposed;
118 self.current_log_density = proposed_log_density;
119 self.n_accepted += 1;
120 }
121
122 self.current.clone()
123 }
124
125 pub fn sample<R: scirs2_core::random::Rng + ?Sized>(
127 &mut self,
128 nsamples_: usize,
129 rng: &mut R,
130 ) -> Array2<f64> {
131 let dim = self.current.len();
132 let mut samples = Array2::zeros((nsamples_, dim));
133
134 for i in 0..nsamples_ {
135 let sample = self.step(rng);
136 samples.row_mut(i).assign(&sample);
137 }
138
139 samples
140 }
141
142 pub fn sample_thinned<R: scirs2_core::random::Rng + ?Sized>(
144 &mut self,
145 n_samples_: usize,
146 thin: usize,
147 rng: &mut R,
148 ) -> Result<Array2<f64>> {
149 check_positive(thin, "thin")?;
150
151 let dim = self.current.len();
152 let mut samples = Array2::zeros((n_samples_, dim));
153
154 for i in 0..n_samples_ {
155 for _ in 0..thin {
157 self.step(rng);
158 }
159 samples.row_mut(i).assign(&self.current);
160 }
161
162 Ok(samples)
163 }
164
165 pub fn acceptance_rate(&self) -> f64 {
167 if self.n_proposed == 0 {
168 0.0
169 } else {
170 self.n_accepted as f64 / self.n_proposed as f64
171 }
172 }
173
174 pub fn reset_counters(&mut self) {
176 self.n_accepted = 0;
177 self.n_proposed = 0;
178 }
179}
180
181pub struct AdaptiveMetropolisHastings<T: TargetDistribution> {
183 pub sampler: MetropolisHastings<T, RandomWalkProposal>,
185 pub target_rate: f64,
187 pub adaptation_rate: f64,
189 pub min_stepsize: f64,
191 pub max_stepsize: f64,
193}
194
195impl<T: TargetDistribution> AdaptiveMetropolisHastings<T> {
196 pub fn new(
198 target: T,
199 initial: Array1<f64>,
200 initial_stepsize: f64,
201 target_rate: f64,
202 ) -> Result<Self> {
203 check_probability(target_rate, "target_rate")?;
204 check_positive(initial_stepsize, "initial_stepsize")?;
205
206 let proposal = RandomWalkProposal::new(initial_stepsize)?;
207 let sampler = MetropolisHastings::new(target, proposal, initial)?;
208
209 Ok(Self {
210 sampler,
211 target_rate,
212 adaptation_rate: 0.05,
213 min_stepsize: 1e-6,
214 max_stepsize: 10.0,
215 })
216 }
217
218 pub fn step<R: scirs2_core::random::Rng + ?Sized>(&mut self, rng: &mut R) -> Array1<f64> {
220 let sample = self.sampler.step(rng);
221
222 if self.sampler.n_proposed.is_multiple_of(100) && self.sampler.n_proposed > 0 {
224 let current_rate = self.sampler.acceptance_rate();
225 let adjustment = 1.0 + self.adaptation_rate * (current_rate - self.target_rate);
226
227 let new_stepsize = (self.sampler.proposal.stepsize * adjustment)
228 .max(self.min_stepsize)
229 .min(self.max_stepsize);
230
231 self.sampler.proposal.stepsize = new_stepsize;
232 }
233
234 sample
235 }
236
237 pub fn adapt<R: scirs2_core::random::Rng + ?Sized>(
239 &mut self,
240 nsteps: usize,
241 rng: &mut R,
242 ) -> Result<()> {
243 check_positive(nsteps, "n_steps")?;
244
245 for _ in 0..nsteps {
246 self.step(rng);
247 }
248
249 self.sampler.reset_counters();
251 Ok(())
252 }
253}
254
255#[derive(Debug, Clone)]
259pub struct MultivariateNormalTarget {
260 pub mean: Array1<f64>,
262 pub precision: Array2<f64>,
264 pub log_norm_const: f64,
266}
267
268impl MultivariateNormalTarget {
269 pub fn new(mean: Array1<f64>, covariance: Array2<f64>) -> Result<Self> {
271 checkarray_finite(&mean, "mean")?;
272 checkarray_finite(&covariance, "covariance")?;
273 if covariance.nrows() != mean.len() || covariance.ncols() != mean.len() {
274 return Err(StatsError::DimensionMismatch(format!(
275 "covariance shape ({}, {}) must be ({}, {})",
276 covariance.nrows(),
277 covariance.ncols(),
278 mean.len(),
279 mean.len()
280 )));
281 }
282
283 let precision = inv(&covariance.view(), None).map_err(|e| {
285 StatsError::ComputationError(format!("Failed to invert covariance matrix: {}", e))
286 })?;
287
288 let det_value = det(&covariance.view(), None).map_err(|e| {
290 StatsError::ComputationError(format!("Failed to compute determinant: {}", e))
291 })?;
292
293 if det_value <= 0.0 {
294 return Err(StatsError::InvalidArgument(
295 "Covariance matrix must be positive definite".to_string(),
296 ));
297 }
298
299 let d = mean.len() as f64;
300 let log_norm_const = -0.5 * (d * (2.0 * std::f64::consts::PI).ln() + det_value.ln());
301
302 Ok(Self {
303 mean,
304 precision,
305 log_norm_const,
306 })
307 }
308}
309
310impl TargetDistribution for MultivariateNormalTarget {
311 fn log_density(&self, x: &Array1<f64>) -> f64 {
312 let diff = x - &self.mean;
313 let quad_form = diff.dot(&self.precision.dot(&diff));
314 self.log_norm_const - 0.5 * quad_form
315 }
316
317 fn dim(&self) -> usize {
318 self.mean.len()
319 }
320}
321
/// Target distribution defined by a user-supplied log-density closure.
pub struct CustomTarget<F> {
    /// Callable that evaluates the log density at a state.
    pub log_density_fn: F,
    /// Dimension of the state vectors the closure expects.
    pub dim: usize,
}
329
330impl<F> CustomTarget<F> {
331 pub fn new(dim: usize, log_densityfn: F) -> Result<Self> {
333 check_positive(dim, "dim")?;
334 Ok(Self {
335 log_density_fn: log_densityfn,
336 dim,
337 })
338 }
339}
340
341impl<F> TargetDistribution for CustomTarget<F>
342where
343 F: Fn(&Array1<f64>) -> f64 + Send + Sync,
344{
345 fn log_density(&self, x: &Array1<f64>) -> f64 {
346 (self.log_density_fn)(x)
347 }
348
349 fn dim(&self) -> usize {
350 self.dim
351 }
352}