1use crate::random::core::Random;
8use ::ndarray::{Array, Array1, Array2};
9use rand::Rng;
10use rand_distr::{Distribution, Gamma, Normal, Uniform};
11use std::f64::consts::PI;
12
13#[derive(Debug, Clone)]
15pub struct Beta {
16 alpha: f64,
17 beta: f64,
18 gamma_alpha: Gamma<f64>,
19 gamma_beta: Gamma<f64>,
20}
21
22impl Beta {
23 pub fn new(alpha: f64, beta: f64) -> Result<Self, String> {
25 if alpha <= 0.0 || beta <= 0.0 {
26 return Err("Alpha and beta parameters must be positive".to_string());
27 }
28
29 let gamma_alpha = Gamma::new(alpha, 1.0).expect("Operation failed");
30 let gamma_beta = Gamma::new(beta, 1.0).expect("Operation failed");
31
32 Ok(Self {
33 alpha,
34 beta,
35 gamma_alpha,
36 gamma_beta,
37 })
38 }
39
40 pub fn sample<R: Rng>(&self, rng: &mut Random<R>) -> f64 {
42 let x = rng.sample(self.gamma_alpha);
43 let y = rng.sample(self.gamma_beta);
44 x / (x + y)
45 }
46
47 pub fn sample_vec<R: Rng>(&self, rng: &mut Random<R>, count: usize) -> Vec<f64> {
49 (0..count).map(|_| self.sample(rng)).collect()
50 }
51
52 pub fn sample_array<R: Rng>(&self, rng: &mut Random<R>, count: usize) -> Array1<f64> {
54 let samples = self.sample_vec(rng, count);
55 Array1::from_vec(samples)
56 }
57
58 pub fn parameters(&self) -> (f64, f64) {
60 (self.alpha, self.beta)
61 }
62
63 pub fn mean(&self) -> f64 {
65 self.alpha / (self.alpha + self.beta)
66 }
67
68 pub fn variance(&self) -> f64 {
70 let ab_sum = self.alpha + self.beta;
71 (self.alpha * self.beta) / (ab_sum * ab_sum * (ab_sum + 1.0))
72 }
73}
74
75#[derive(Debug, Clone)]
77pub struct Categorical {
78 weights: Vec<f64>,
79 cumulative: Vec<f64>,
80}
81
82impl Categorical {
83 pub fn new(weights: Vec<f64>) -> Result<Self, String> {
85 if weights.is_empty() {
86 return Err("Weights vector cannot be empty".to_string());
87 }
88
89 if weights.iter().any(|&w| w < 0.0) {
90 return Err("All weights must be non-negative".to_string());
91 }
92
93 let total: f64 = weights.iter().sum();
94 if total <= 0.0 {
95 return Err("Sum of weights must be positive".to_string());
96 }
97
98 let mut cumulative = Vec::with_capacity(weights.len());
100 let mut sum = 0.0;
101 for &weight in &weights {
102 sum += weight / total;
103 cumulative.push(sum);
104 }
105
106 Ok(Self {
107 weights,
108 cumulative,
109 })
110 }
111
112 pub fn sample<R: Rng>(&self, rng: &mut Random<R>) -> usize {
114 let u = rng.sample(Uniform::new(0.0, 1.0).expect("Operation failed"));
115
116 match self
118 .cumulative
119 .binary_search_by(|&x| x.partial_cmp(&u).expect("Operation failed"))
120 {
121 Ok(idx) => idx,
122 Err(idx) => idx.min(self.cumulative.len() - 1),
123 }
124 }
125
126 pub fn len(&self) -> usize {
128 self.weights.len()
129 }
130
131 pub fn is_empty(&self) -> bool {
133 self.weights.is_empty()
134 }
135
136 pub fn sample_vec<R: Rng>(&self, rng: &mut Random<R>, count: usize) -> Vec<usize> {
138 (0..count).map(|_| self.sample(rng)).collect()
139 }
140
141 pub fn probability(&self, i: usize) -> Option<f64> {
143 if i < self.weights.len() {
144 let total: f64 = self.weights.iter().sum();
145 Some(self.weights[i] / total)
146 } else {
147 None
148 }
149 }
150}
151
152#[derive(Debug, Clone)]
154pub struct WeightedChoice<T> {
155 items: Vec<T>,
156 categorical: Categorical,
157}
158
159impl<T: Clone> WeightedChoice<T> {
160 pub fn new(items: Vec<T>, weights: Vec<f64>) -> Result<Self, String> {
162 if items.len() != weights.len() {
163 return Err("Items and weights must have the same length".to_string());
164 }
165
166 let categorical = Categorical::new(weights)?;
167
168 Ok(Self { items, categorical })
169 }
170
171 pub fn sample<R: Rng>(&self, rng: &mut Random<R>) -> &T {
173 let index = self.categorical.sample(rng);
174 &self.items[index]
175 }
176
177 pub fn sample_cloned<R: Rng>(&self, rng: &mut Random<R>) -> T {
179 self.sample(rng).clone()
180 }
181
182 pub fn sample_vec<R: Rng>(&self, rng: &mut Random<R>, count: usize) -> Vec<T> {
184 (0..count).map(|_| self.sample_cloned(rng)).collect()
185 }
186
187 pub fn len(&self) -> usize {
189 self.items.len()
190 }
191
192 pub fn is_empty(&self) -> bool {
194 self.items.is_empty()
195 }
196
197 pub fn items_and_probabilities(&self) -> Vec<(&T, f64)> {
199 self.items
200 .iter()
201 .enumerate()
202 .map(|(i, item)| (item, self.categorical.probability(i).unwrap_or(0.0)))
203 .collect()
204 }
205}
206
207#[derive(Debug, Clone)]
209pub struct ExponentialDist {
210 lambda: f64,
211 exponential: rand_distr::Exp<f64>,
212}
213
214impl ExponentialDist {
215 pub fn new(lambda: f64) -> Result<Self, String> {
217 if lambda <= 0.0 {
218 return Err("Lambda parameter must be positive".to_string());
219 }
220
221 let exponential = rand_distr::Exp::new(lambda).expect("Operation failed");
222
223 Ok(Self {
224 lambda,
225 exponential,
226 })
227 }
228
229 pub fn sample<R: Rng>(&self, rng: &mut Random<R>) -> f64 {
231 rng.sample(self.exponential)
232 }
233
234 pub fn sample_vec<R: Rng>(&self, rng: &mut Random<R>, count: usize) -> Vec<f64> {
236 (0..count).map(|_| self.sample(rng)).collect()
237 }
238
239 pub fn sample_array<R: Rng>(&self, rng: &mut Random<R>, count: usize) -> Array1<f64> {
241 let samples = self.sample_vec(rng, count);
242 Array1::from_vec(samples)
243 }
244
245 pub fn lambda(&self) -> f64 {
247 self.lambda
248 }
249
250 pub fn mean(&self) -> f64 {
252 1.0 / self.lambda
253 }
254
255 pub fn variance(&self) -> f64 {
257 1.0 / (self.lambda * self.lambda)
258 }
259}
260
261#[derive(Debug, Clone)]
263pub struct GammaDist {
264 alpha: f64,
265 beta: f64,
266 gamma: Gamma<f64>,
267}
268
269impl GammaDist {
270 pub fn new(alpha: f64, beta: f64) -> Result<Self, String> {
272 if alpha <= 0.0 || beta <= 0.0 {
273 return Err("Alpha and beta parameters must be positive".to_string());
274 }
275
276 let gamma = Gamma::new(alpha, beta).expect("Operation failed");
277
278 Ok(Self { alpha, beta, gamma })
279 }
280
281 pub fn sample<R: Rng>(&self, rng: &mut Random<R>) -> f64 {
283 rng.sample(self.gamma)
284 }
285
286 pub fn sample_vec<R: Rng>(&self, rng: &mut Random<R>, count: usize) -> Vec<f64> {
288 (0..count).map(|_| self.sample(rng)).collect()
289 }
290
291 pub fn sample_array<R: Rng>(&self, rng: &mut Random<R>, count: usize) -> Array1<f64> {
293 let samples = self.sample_vec(rng, count);
294 Array1::from_vec(samples)
295 }
296
297 pub fn parameters(&self) -> (f64, f64) {
299 (self.alpha, self.beta)
300 }
301
302 pub fn mean(&self) -> f64 {
304 self.alpha * self.beta
305 }
306
307 pub fn variance(&self) -> f64 {
309 self.alpha * self.beta * self.beta
310 }
311}
312
313#[derive(Debug, Clone)]
315pub struct VonMises {
316 mu: f64,
317 kappa: f64,
318}
319
320impl VonMises {
321 pub fn mu(mu: f64, kappa: f64) -> Result<Self, String> {
323 if kappa < 0.0 {
324 return Err("Kappa parameter must be non-negative".to_string());
325 }
326
327 Ok(Self { mu, kappa })
328 }
329
330 pub fn sample<R: Rng>(&self, rng: &mut Random<R>) -> f64 {
332 if self.kappa < 1e-6 {
333 return rng.sample(Uniform::new(0.0, 2.0 * PI).expect("Operation failed"));
335 }
336
337 let s = 0.5 / self.kappa;
339 let r = s + (1.0 + s * s).sqrt();
340
341 loop {
342 let u1 = rng.sample(Uniform::new(0.0, 1.0).expect("Operation failed"));
343 let z = (r * u1).cos();
344 let d = z / (r + z);
345 let u2 = rng.sample(Uniform::new(0.0, 1.0).expect("Operation failed"));
346
347 if u2 < 1.0 - d * d || u2 <= (1.0 - d) * (-self.kappa * d).exp() {
348 let u3 = rng.sample(Uniform::new(0.0, 1.0).expect("Operation failed"));
349 let theta = if u3 > 0.5 {
350 self.mu + d.acos()
351 } else {
352 self.mu - d.acos()
353 };
354 return ((theta % (2.0 * PI)) + 2.0 * PI) % (2.0 * PI);
355 }
356 }
357 }
358
359 pub fn sample_vec<R: Rng>(&self, rng: &mut Random<R>, count: usize) -> Vec<f64> {
361 (0..count).map(|_| self.sample(rng)).collect()
362 }
363
364 pub fn parameters(&self) -> (f64, f64) {
366 (self.mu, self.kappa)
367 }
368}
369
370#[derive(Debug, Clone)]
372pub struct MultivariateNormal {
373 mean: Vec<f64>,
374 cholesky: Array2<f64>,
375 dimension: usize,
376}
377
378impl MultivariateNormal {
379 pub fn new(mean: Vec<f64>, covariance: Vec<Vec<f64>>) -> Result<Self, String> {
381 let dimension = mean.len();
382
383 if covariance.len() != dimension {
384 return Err("Covariance matrix must be square and match mean dimension".to_string());
385 }
386
387 for row in &covariance {
388 if row.len() != dimension {
389 return Err("Covariance matrix must be square".to_string());
390 }
391 }
392
393 let mut cov_array = Array2::zeros((dimension, dimension));
395 for (i, row) in covariance.iter().enumerate() {
396 for (j, &val) in row.iter().enumerate() {
397 cov_array[[i, j]] = val;
398 }
399 }
400
401 let cholesky = Self::cholesky_decomposition(cov_array)?;
403
404 Ok(Self {
405 mean,
406 cholesky,
407 dimension,
408 })
409 }
410
411 fn cholesky_decomposition(mut a: Array2<f64>) -> Result<Array2<f64>, String> {
413 let n = a.nrows();
414 let mut l = Array2::zeros((n, n));
415
416 for i in 0..n {
417 for j in 0..=i {
418 if i == j {
419 let mut sum = 0.0;
420 for k in 0..j {
421 sum += l[[j, k]] * l[[j, k]];
422 }
423 let val = a[[j, j]] - sum;
424 if val <= 0.0 {
425 return Err("Matrix is not positive definite".to_string());
426 }
427 l[[j, j]] = val.sqrt();
428 } else {
429 let mut sum = 0.0;
430 for k in 0..j {
431 sum += l[[i, k]] * l[[j, k]];
432 }
433 l[[i, j]] = (a[[i, j]] - sum) / l[[j, j]];
434 }
435 }
436 }
437
438 Ok(l)
439 }
440
441 pub fn sample<R: Rng>(&self, rng: &mut Random<R>) -> Vec<f64> {
443 let standard_normal = Normal::new(0.0, 1.0).expect("Operation failed");
445 let z: Vec<f64> = (0..self.dimension)
446 .map(|_| rng.sample(standard_normal))
447 .collect();
448
449 let mut result = self.mean.clone();
451 for i in 0..self.dimension {
452 for j in 0..=i {
453 result[i] += self.cholesky[[i, j]] * z[j];
454 }
455 }
456
457 result
458 }
459
460 pub fn sample_vec<R: Rng>(&self, rng: &mut Random<R>, count: usize) -> Vec<Vec<f64>> {
462 (0..count).map(|_| self.sample(rng)).collect()
463 }
464
465 pub fn sample_array<R: Rng>(&self, rng: &mut Random<R>, count: usize) -> Array2<f64> {
467 let samples = self.sample_vec(rng, count);
468 let mut array = Array2::zeros((count, self.dimension));
469
470 for (i, sample) in samples.iter().enumerate() {
471 for (j, &val) in sample.iter().enumerate() {
472 array[[i, j]] = val;
473 }
474 }
475
476 array
477 }
478
479 pub fn dimension(&self) -> usize {
481 self.dimension
482 }
483
484 pub fn mean(&self) -> &Vec<f64> {
486 &self.mean
487 }
488}
489
490#[derive(Debug, Clone)]
492pub struct Dirichlet {
493 alphas: Vec<f64>,
494 gamma_distributions: Vec<Gamma<f64>>,
495}
496
497impl Dirichlet {
498 pub fn new(alphas: Vec<f64>) -> Result<Self, String> {
500 if alphas.is_empty() {
501 return Err("Alpha parameters cannot be empty".to_string());
502 }
503
504 if alphas.iter().any(|&alpha| alpha <= 0.0) {
505 return Err("All alpha parameters must be positive".to_string());
506 }
507
508 let gamma_distributions: Result<Vec<_>, _> =
509 alphas.iter().map(|&alpha| Gamma::new(alpha, 1.0)).collect();
510
511 let gamma_distributions = gamma_distributions.expect("Operation failed");
512
513 Ok(Self {
514 alphas,
515 gamma_distributions,
516 })
517 }
518
519 pub fn sample<R: Rng>(&self, rng: &mut Random<R>) -> Vec<f64> {
521 let gamma_samples: Vec<f64> = self
523 .gamma_distributions
524 .iter()
525 .map(|gamma| rng.sample(*gamma))
526 .collect();
527
528 let sum: f64 = gamma_samples.iter().sum();
530 gamma_samples.into_iter().map(|x| x / sum).collect()
531 }
532
533 pub fn sample_vec<R: Rng>(&self, rng: &mut Random<R>, count: usize) -> Vec<Vec<f64>> {
535 (0..count).map(|_| self.sample(rng)).collect()
536 }
537
538 pub fn sample_array<R: Rng>(&self, rng: &mut Random<R>, count: usize) -> Array2<f64> {
540 let samples = self.sample_vec(rng, count);
541 let mut array = Array2::zeros((count, self.alphas.len()));
542
543 for (i, sample) in samples.iter().enumerate() {
544 for (j, &val) in sample.iter().enumerate() {
545 array[[i, j]] = val;
546 }
547 }
548
549 array
550 }
551
552 pub fn dimension(&self) -> usize {
554 self.alphas.len()
555 }
556
557 pub fn alphas(&self) -> &Vec<f64> {
559 &self.alphas
560 }
561
562 pub fn mean(&self) -> Vec<f64> {
564 let sum: f64 = self.alphas.iter().sum();
565 self.alphas.iter().map(|&alpha| alpha / sum).collect()
566 }
567}
568
#[cfg(test)]
mod tests {
    use super::*;
    use crate::random::core::seeded_rng;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_beta_distribution() {
        let dist = Beta::new(2.0, 3.0).expect("Operation failed");
        let mut rng = seeded_rng(42);
        let in_unit = |v: &f64| (0.0..1.0).contains(v);

        let single = dist.sample(&mut rng);
        assert!(in_unit(&single));

        let many = dist.sample_vec(&mut rng, 100);
        assert_eq!(many.len(), 100);
        assert!(many.iter().all(in_unit));

        // Beta(2, 3) has mean 2 / 5.
        assert_abs_diff_eq!(dist.mean(), 0.4, epsilon = 1e-10);
        assert!(dist.variance() > 0.0);

        for (a, b) in [(-1.0, 2.0), (2.0, -1.0)] {
            assert!(Beta::new(a, b).is_err());
        }
    }

    #[test]
    fn test_categorical_distribution() {
        let expected_probs: [f64; 3] = [0.2, 0.3, 0.5];
        let categorical = Categorical::new(expected_probs.to_vec()).expect("Operation failed");
        let mut rng = seeded_rng(123);

        let draws = categorical.sample_vec(&mut rng, 1000);
        assert_eq!(draws.len(), 1000);
        assert!(draws.iter().all(|&k| k < 3));

        // Every category should appear at least once in 1000 draws.
        let mut tally = [0usize; 3];
        for &k in &draws {
            tally[k] += 1;
        }
        assert!(tally.iter().all(|&seen| seen > 0));

        for (k, &expected) in expected_probs.iter().enumerate() {
            assert_abs_diff_eq!(
                categorical.probability(k).expect("Operation failed"),
                expected,
                epsilon = 1e-10
            );
        }

        assert!(Categorical::new(vec![]).is_err());
        assert!(Categorical::new(vec![-1.0, 0.5]).is_err());
    }

    #[test]
    fn test_multivariate_normal() {
        let mvn = MultivariateNormal::new(vec![0.0, 0.0], vec![vec![1.0, 0.5], vec![0.5, 1.0]])
            .expect("Operation failed");
        let mut rng = seeded_rng(456);

        let first = mvn.sample(&mut rng);
        assert_eq!(first.len(), 2);
        assert_eq!(mvn.dimension(), 2);

        let batch = mvn.sample_vec(&mut rng, 10);
        assert_eq!(batch.len(), 10);
        assert!(batch.iter().all(|row| row.len() == 2));
    }

    #[test]
    fn test_dirichlet_distribution() {
        let dirichlet = Dirichlet::new(vec![1.0, 2.0, 3.0]).expect("Operation failed");
        let mut rng = seeded_rng(789);

        let draw = dirichlet.sample(&mut rng);
        assert_eq!(draw.len(), 3);
        let draw_total: f64 = draw.iter().sum();
        assert_abs_diff_eq!(draw_total, 1.0, epsilon = 1e-10);
        assert!(draw.iter().all(|&p| p >= 0.0));

        let expected_mean = dirichlet.mean();
        assert_eq!(expected_mean.len(), 3);
        let mean_total: f64 = expected_mean.iter().sum();
        assert_abs_diff_eq!(mean_total, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_von_mises_distribution() {
        let vm = VonMises::mu(0.0, 1.0).expect("Operation failed");
        let mut rng = seeded_rng(101112);

        let angles = vm.sample_vec(&mut rng, 100);
        assert_eq!(angles.len(), 100);
        let full_turn = 2.0 * PI;
        assert!(angles.iter().all(|&a| (0.0..full_turn).contains(&a)));

        assert_eq!(vm.parameters(), (0.0, 1.0));
    }

    #[test]
    fn test_weighted_choice() {
        let chooser = WeightedChoice::new(vec!["A", "B", "C"], vec![0.2, 0.3, 0.5])
            .expect("Operation failed");
        let mut rng = seeded_rng(131415);

        let picks = chooser.sample_vec(&mut rng, 100);
        assert_eq!(picks.len(), 100);
        let allowed = ["A", "B", "C"];
        assert!(picks.iter().all(|p| allowed.contains(p)));

        assert_eq!(chooser.items_and_probabilities().len(), 3);

        // Mismatched item/weight counts must be rejected.
        assert!(WeightedChoice::new(vec!["A", "B"], vec![0.2, 0.3, 0.5]).is_err());
    }
}