1use crate::random::core::Random;
29use ::ndarray::{Array1, ArrayD, Dimension, IxDyn};
30use rand::Rng;
31use rand_distr::Distribution;
32use std::fmt;
33
/// Error type shared by all `Unified*` distribution wrappers in this module.
#[derive(Clone, Debug)]
pub enum UnifiedDistributionError {
    /// A parameter was rejected; displayed as `Invalid parameter: <msg>`.
    InvalidParameter(String),
    /// Building the underlying distribution failed; displayed as
    /// `Construction failed: <msg>`.
    ConstructionFailed(String),
    /// Any other failure; displayed as the bare message.
    Other(String),
}

impl fmt::Display for UnifiedDistributionError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::InvalidParameter(detail) => write!(f, "Invalid parameter: {}", detail),
            Self::ConstructionFailed(detail) => write!(f, "Construction failed: {}", detail),
            // No prefix for the catch-all variant: the message is shown verbatim.
            Self::Other(detail) => f.write_str(detail),
        }
    }
}

impl std::error::Error for UnifiedDistributionError {}
56
57impl From<rand_distr::NormalError> for UnifiedDistributionError {
59 fn from(e: rand_distr::NormalError) -> Self {
60 Self::ConstructionFailed(format!("Normal distribution error: {:?}", e))
61 }
62}
63
64impl From<rand_distr::BetaError> for UnifiedDistributionError {
65 fn from(e: rand_distr::BetaError) -> Self {
66 Self::ConstructionFailed(format!("Beta distribution error: {:?}", e))
67 }
68}
69
70impl From<rand_distr::CauchyError> for UnifiedDistributionError {
71 fn from(e: rand_distr::CauchyError) -> Self {
72 Self::ConstructionFailed(format!("Cauchy distribution error: {:?}", e))
73 }
74}
75
76impl From<rand_distr::ChiSquaredError> for UnifiedDistributionError {
77 fn from(e: rand_distr::ChiSquaredError) -> Self {
78 Self::ConstructionFailed(format!("ChiSquared distribution error: {:?}", e))
79 }
80}
81
82impl From<rand_distr::FisherFError> for UnifiedDistributionError {
83 fn from(e: rand_distr::FisherFError) -> Self {
84 Self::ConstructionFailed(format!("FisherF distribution error: {:?}", e))
85 }
86}
87
88impl From<rand_distr::ExpError> for UnifiedDistributionError {
89 fn from(e: rand_distr::ExpError) -> Self {
90 Self::ConstructionFailed(format!("Exponential distribution error: {:?}", e))
91 }
92}
93
94impl From<rand_distr::GammaError> for UnifiedDistributionError {
95 fn from(e: rand_distr::GammaError) -> Self {
96 Self::ConstructionFailed(format!("Gamma distribution error: {:?}", e))
97 }
98}
99
100impl From<rand_distr::WeibullError> for UnifiedDistributionError {
101 fn from(e: rand_distr::WeibullError) -> Self {
102 Self::ConstructionFailed(format!("Weibull distribution error: {:?}", e))
103 }
104}
105
106impl From<rand_distr::BinomialError> for UnifiedDistributionError {
107 fn from(e: rand_distr::BinomialError) -> Self {
108 Self::ConstructionFailed(format!("Binomial distribution error: {:?}", e))
109 }
110}
111
112impl From<rand_distr::PoissonError> for UnifiedDistributionError {
113 fn from(e: rand_distr::PoissonError) -> Self {
114 Self::ConstructionFailed(format!("Poisson distribution error: {:?}", e))
115 }
116}
117
118impl From<std::io::Error> for UnifiedDistributionError {
119 fn from(e: std::io::Error) -> Self {
120 Self::Other(e.to_string())
121 }
122}
123
124pub trait UnifiedDistribution<T> {
126 fn sample_unified<R: Rng>(&self, rng: &mut Random<R>) -> T;
128
129 fn sample_vec<R: Rng>(&self, rng: &mut Random<R>, n: usize) -> Vec<T>;
131
132 fn sample_array<R: Rng>(&self, rng: &mut Random<R>, shape: IxDyn) -> ArrayD<T>
134 where
135 T: Clone;
136
137 fn parameters_string(&self) -> String;
139
140 fn validate(&self) -> Result<(), UnifiedDistributionError>;
142}
143
/// Generates a `Unified*` wrapper struct around a `rand_distr` distribution.
///
/// Arguments:
/// - `$name`: name of the wrapper struct to define;
/// - `$inner`: the wrapped distribution type (stored in a private `inner` field);
/// - `$output`: the sample type;
/// - `$params`: a closure `&$inner -> String` used by `parameters_string`.
///
/// The generated struct implements [`UnifiedDistribution`], which samples
/// through the project `Random` wrapper, and `rand_distr::Distribution`, which
/// samples through any `Rng`. Both delegate to the wrapped distribution.
/// Constructors are written by hand next to each invocation; the generated
/// `validate` unconditionally returns `Ok(())`.
macro_rules! impl_unified_distribution {
    ($name:ident, $inner:ty, $output:ty, $params:expr) => {
        #[doc = concat!("Unified wrapper around `", stringify!($inner), "`.")]
        #[derive(Debug, Clone)]
        pub struct $name {
            inner: $inner,
        }

        impl $name {
            /// Returns a shared reference to the wrapped distribution.
            pub fn inner(&self) -> &$inner {
                &self.inner
            }

            /// Returns a mutable reference to the wrapped distribution.
            pub fn inner_mut(&mut self) -> &mut $inner {
                &mut self.inner
            }
        }

        impl UnifiedDistribution<$output> for $name {
            fn sample_unified<R: Rng>(&self, rng: &mut Random<R>) -> $output {
                rng.sample(&self.inner)
            }

            fn sample_vec<R: Rng>(&self, rng: &mut Random<R>, n: usize) -> Vec<$output> {
                (0..n).map(|_| self.sample_unified(rng)).collect()
            }

            fn sample_array<R: Rng>(&self, rng: &mut Random<R>, shape: IxDyn) -> ArrayD<$output>
            where
                $output: Clone,
            {
                let size = shape.size();
                let values = self.sample_vec(rng, size);
                // `values.len() == size` by construction, so the shape always matches.
                ArrayD::from_shape_vec(shape, values)
                    .expect("sample_vec returns exactly shape.size() elements")
            }

            fn parameters_string(&self) -> String {
                $params(&self.inner)
            }

            fn validate(&self) -> Result<(), UnifiedDistributionError> {
                Ok(())
            }
        }

        impl Distribution<$output> for $name {
            fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> $output {
                self.inner.sample(rng)
            }
        }
    };
}
199
200impl_unified_distribution!(
203 UnifiedNormal,
204 rand_distr::Normal<f64>,
205 f64,
206 |d: &rand_distr::Normal<f64>| format!("Normal(mean={}, std={})", d.mean(), d.std_dev())
207);
208
209impl UnifiedNormal {
210 pub fn new(mean: f64, std_dev: f64) -> Result<Self, UnifiedDistributionError> {
211 Ok(Self {
212 inner: rand_distr::Normal::new(mean, std_dev)?,
213 })
214 }
215
216 pub fn mean(&self) -> f64 {
217 self.inner.mean()
218 }
219
220 pub fn std_dev(&self) -> f64 {
221 self.inner.std_dev()
222 }
223}
224
225impl_unified_distribution!(
226 UnifiedBeta,
227 rand_distr::Beta<f64>,
228 f64,
229 |_: &rand_distr::Beta<f64>| "Beta(alpha, beta)".to_string()
230);
231
232impl UnifiedBeta {
233 pub fn new(alpha: f64, beta: f64) -> Result<Self, UnifiedDistributionError> {
234 Ok(Self {
235 inner: rand_distr::Beta::new(alpha, beta)?,
236 })
237 }
238}
239
240impl_unified_distribution!(
241 UnifiedCauchy,
242 rand_distr::Cauchy<f64>,
243 f64,
244 |_: &rand_distr::Cauchy<f64>| "Cauchy(median, scale)".to_string()
245);
246
247impl UnifiedCauchy {
248 pub fn new(median: f64, scale: f64) -> Result<Self, UnifiedDistributionError> {
249 Ok(Self {
250 inner: rand_distr::Cauchy::new(median, scale)?,
251 })
252 }
253}
254
255impl_unified_distribution!(
256 UnifiedChiSquared,
257 rand_distr::ChiSquared<f64>,
258 f64,
259 |_: &rand_distr::ChiSquared<f64>| "ChiSquared(k)".to_string()
260);
261
262impl UnifiedChiSquared {
263 pub fn new(k: f64) -> Result<Self, UnifiedDistributionError> {
264 Ok(Self {
265 inner: rand_distr::ChiSquared::new(k)?,
266 })
267 }
268}
269
270impl_unified_distribution!(
271 UnifiedFisherF,
272 rand_distr::FisherF<f64>,
273 f64,
274 |_: &rand_distr::FisherF<f64>| "FisherF(m, n)".to_string()
275);
276
277impl UnifiedFisherF {
278 pub fn new(m: f64, n: f64) -> Result<Self, UnifiedDistributionError> {
279 Ok(Self {
280 inner: rand_distr::FisherF::new(m, n)?,
281 })
282 }
283}
284
285impl_unified_distribution!(
286 UnifiedStudentT,
287 rand_distr::StudentT<f64>,
288 f64,
289 |_: &rand_distr::StudentT<f64>| "StudentT(n)".to_string()
290);
291
292impl UnifiedStudentT {
293 pub fn new(n: f64) -> Result<Self, UnifiedDistributionError> {
294 Ok(Self {
295 inner: rand_distr::StudentT::new(n)?,
296 })
297 }
298}
299
300impl_unified_distribution!(
301 UnifiedLogNormal,
302 rand_distr::LogNormal<f64>,
303 f64,
304 |_: &rand_distr::LogNormal<f64>| "LogNormal(mean, std)".to_string()
305);
306
307impl UnifiedLogNormal {
308 pub fn new(mean: f64, std_dev: f64) -> Result<Self, UnifiedDistributionError> {
309 Ok(Self {
310 inner: rand_distr::LogNormal::new(mean, std_dev)?,
311 })
312 }
313}
314
315impl_unified_distribution!(
316 UnifiedWeibull,
317 rand_distr::Weibull<f64>,
318 f64,
319 |_: &rand_distr::Weibull<f64>| "Weibull(scale, shape)".to_string()
320);
321
322impl UnifiedWeibull {
323 pub fn new(scale: f64, shape: f64) -> Result<Self, UnifiedDistributionError> {
324 Ok(Self {
325 inner: rand_distr::Weibull::new(scale, shape)?,
326 })
327 }
328}
329
330impl_unified_distribution!(
331 UnifiedGamma,
332 rand_distr::Gamma<f64>,
333 f64,
334 |_: &rand_distr::Gamma<f64>| "Gamma(shape, scale)".to_string()
335);
336
337impl UnifiedGamma {
338 pub fn new(shape: f64, scale: f64) -> Result<Self, UnifiedDistributionError> {
339 Ok(Self {
340 inner: rand_distr::Gamma::new(shape, scale)?,
341 })
342 }
343}
344
345impl_unified_distribution!(
346 UnifiedExp,
347 rand_distr::Exp<f64>,
348 f64,
349 |_: &rand_distr::Exp<f64>| "Exp(lambda)".to_string()
350);
351
352impl UnifiedExp {
353 pub fn new(lambda: f64) -> Result<Self, UnifiedDistributionError> {
354 Ok(Self {
355 inner: rand_distr::Exp::new(lambda)?,
356 })
357 }
358}
359
360impl_unified_distribution!(
363 UnifiedBinomial,
364 rand_distr::Binomial,
365 u64,
366 |_: &rand_distr::Binomial| "Binomial(n, p)".to_string()
367);
368
369impl UnifiedBinomial {
370 pub fn new(n: u64, p: f64) -> Result<Self, UnifiedDistributionError> {
371 Ok(Self {
372 inner: rand_distr::Binomial::new(n, p)?,
373 })
374 }
375}
376
377impl_unified_distribution!(
379 UnifiedPoisson,
380 rand_distr::Poisson<f64>,
381 f64,
382 |_: &rand_distr::Poisson<f64>| "Poisson(lambda)".to_string()
383);
384
385impl UnifiedPoisson {
386 pub fn new(lambda: f64) -> Result<Self, UnifiedDistributionError> {
387 Ok(Self {
388 inner: rand_distr::Poisson::new(lambda)?,
389 })
390 }
391}
392
393#[derive(Debug, Clone)]
400pub struct UnifiedDirichlet {
401 inner: crate::random::distributions::Dirichlet,
402}
403
404impl UnifiedDirichlet {
405 pub fn new(alpha: Vec<f64>) -> Result<Self, UnifiedDistributionError> {
406 Ok(Self {
407 inner: crate::random::distributions::Dirichlet::new(alpha).map_err(|e| {
408 UnifiedDistributionError::ConstructionFailed(format!("Dirichlet error: {}", e))
409 })?,
410 })
411 }
412
413 pub fn sample<R: Rng>(&self, rng: &mut Random<R>) -> Vec<f64> {
415 self.inner.sample(rng)
416 }
417
418 pub fn sample_array<R: Rng>(&self, rng: &mut Random<R>) -> Array1<f64> {
420 Array1::from_vec(self.sample(rng))
421 }
422
423 pub fn sample_multiple<R: Rng>(&self, rng: &mut Random<R>, n: usize) -> Vec<Vec<f64>> {
425 (0..n).map(|_| self.sample(rng)).collect()
426 }
427
428 pub fn alphas(&self) -> &[f64] {
430 self.inner.alphas()
431 }
432}
433
434impl Distribution<Vec<f64>> for UnifiedDirichlet {
435 fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Vec<f64> {
436 use rand_distr::Gamma;
439 let gamma_samples: Vec<f64> = self
440 .inner
441 .alphas()
442 .iter()
443 .map(|&alpha| {
444 let gamma = Gamma::new(alpha, 1.0).expect("Operation failed");
445 rng.sample(gamma)
446 })
447 .collect();
448
449 let sum: f64 = gamma_samples.iter().sum();
451 gamma_samples.into_iter().map(|x| x / sum).collect()
452 }
453}
454
455impl UnifiedDistribution<Vec<f64>> for UnifiedDirichlet {
456 fn sample_unified<R: Rng>(&self, rng: &mut Random<R>) -> Vec<f64> {
457 self.sample(rng)
458 }
459
460 fn sample_vec<R: Rng>(&self, rng: &mut Random<R>, n: usize) -> Vec<Vec<f64>> {
461 self.sample_multiple(rng, n)
462 }
463
464 fn sample_array<R: Rng>(&self, rng: &mut Random<R>, shape: IxDyn) -> ArrayD<Vec<f64>> {
465 let size = shape.size();
466 let values = self.sample_vec(rng, size);
467 ArrayD::from_shape_vec(shape, values).expect("Operation failed")
468 }
469
470 fn parameters_string(&self) -> String {
471 format!("Dirichlet(alpha=[{} values])", self.alphas().len())
472 }
473
474 fn validate(&self) -> Result<(), UnifiedDistributionError> {
475 Ok(())
476 }
477}
478
479pub mod defaults {
481 use super::*;
482
483 pub fn standard_normal() -> UnifiedNormal {
485 UnifiedNormal::new(0.0, 1.0).expect("Operation failed")
486 }
487
488 pub fn uniform_01() -> rand_distr::Uniform<f64> {
490 rand_distr::Uniform::new(0.0, 1.0).expect("Operation failed")
491 }
492
493 pub fn standard_exponential() -> UnifiedExp {
495 UnifiedExp::new(1.0).expect("Operation failed")
496 }
497
498 pub fn standard_gamma() -> UnifiedGamma {
500 UnifiedGamma::new(1.0, 1.0).expect("Operation failed")
501 }
502}
503
#[cfg(test)]
mod tests {
    use super::*;
    use crate::random::thread_rng;

    #[test]
    fn test_unified_normal() {
        let dist = UnifiedNormal::new(0.0, 1.0).expect("Operation failed");
        let mut rng = thread_rng();

        let sample = dist.sample_unified(&mut rng);
        assert!(sample.is_finite());

        let samples = dist.sample_vec(&mut rng, 100);
        assert_eq!(samples.len(), 100);
    }

    #[test]
    fn test_unified_beta() {
        let dist = UnifiedBeta::new(2.0, 5.0).expect("Operation failed");
        let mut rng = thread_rng();

        let sample = dist.sample_unified(&mut rng);
        assert!((0.0..=1.0).contains(&sample));
    }

    #[test]
    fn test_unified_poisson() {
        let dist = UnifiedPoisson::new(5.0).expect("Operation failed");
        let mut rng = thread_rng();

        let sample = dist.sample_unified(&mut rng);
        assert!(sample >= 0.0);
    }

    #[test]
    fn test_unified_dirichlet() {
        let dist = UnifiedDirichlet::new(vec![1.0, 2.0, 3.0]).expect("Operation failed");
        let mut rng = thread_rng();

        let sample = dist.sample(&mut rng);
        assert_eq!(sample.len(), 3);

        let sum: f64 = sample.iter().sum();
        assert!((sum - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_distribution_trait() {
        let dist = UnifiedNormal::new(0.0, 1.0).expect("Operation failed");
        let mut rng = rand::thread_rng();

        let sample: f64 = rng.sample(&dist);
        assert!(sample.is_finite());
    }

    #[test]
    fn test_sample_array_shape() {
        let dist = UnifiedNormal::new(0.0, 1.0).expect("valid normal parameters");
        let mut rng = thread_rng();

        let arr = dist.sample_array(&mut rng, IxDyn(&[2, 3]));
        assert_eq!(arr.shape(), &[2, 3]);
        assert!(arr.iter().all(|x| x.is_finite()));
    }

    #[test]
    fn test_invalid_parameters_are_rejected() {
        assert!(matches!(
            UnifiedNormal::new(0.0, -1.0),
            Err(UnifiedDistributionError::ConstructionFailed(_))
        ));
        assert!(matches!(
            UnifiedBeta::new(-1.0, 1.0),
            Err(UnifiedDistributionError::ConstructionFailed(_))
        ));
        assert!(matches!(
            UnifiedExp::new(-1.0),
            Err(UnifiedDistributionError::ConstructionFailed(_))
        ));
        assert!(matches!(
            UnifiedBinomial::new(10, 1.5),
            Err(UnifiedDistributionError::ConstructionFailed(_))
        ));
        assert!(matches!(
            UnifiedPoisson::new(-1.0),
            Err(UnifiedDistributionError::ConstructionFailed(_))
        ));
    }

    #[test]
    fn test_normal_parameters_string() {
        let dist = UnifiedNormal::new(1.5, 2.0).expect("valid normal parameters");
        assert_eq!(dist.parameters_string(), "Normal(mean=1.5, std=2)");
        assert!(dist.validate().is_ok());
    }

    #[test]
    fn test_dirichlet_distribution_trait() {
        let dist = UnifiedDirichlet::new(vec![1.0, 2.0, 3.0]).expect("valid alphas");
        let mut rng = rand::thread_rng();

        let sample: Vec<f64> = rng.sample(&dist);
        assert_eq!(sample.len(), 3);
        assert!(sample.iter().all(|&x| (0.0..=1.0).contains(&x)));

        let sum: f64 = sample.iter().sum();
        assert!((sum - 1.0).abs() < 1e-10);
    }
}