// simd_kernels/kernels/scientific/distributions/shared/sampler.rs
use minarrow::Vec64;
use rand::rngs::ThreadRng;
use rand::{Rng, rng};
use std::f64::consts::PI;

14pub struct Sampler {
16 rng: ThreadRng,
17}
18
19impl Sampler {
20 #[inline]
22 pub fn new() -> Self {
23 Sampler { rng: rng() }
24 }
25
26 #[inline]
28 pub fn sample_standard_normal(&mut self) -> f64 {
29 sample_standard_normal(&mut self.rng)
30 }
31
32 #[inline]
34 fn sample_gamma(&mut self, shape: f64, scale: f64) -> f64 {
35 sample_gamma(&mut self.rng, shape, scale)
36 }
37
38 #[inline]
40 pub fn gamma(&mut self, shape: f64, scale: f64) -> f64 {
41 self.sample_gamma(shape, scale)
42 }
43
44 #[inline]
46 pub fn chi2(&mut self, df: f64) -> f64 {
47 assert!(df.is_finite() && df > 0.0, "df must be finite and > 0");
48 self.gamma(df * 0.5, 2.0)
49 }
50
51 #[inline]
53 pub fn standard_normal_vec(&mut self, dim: usize) -> Vec64<f64> {
54 let mut v = Vec64::with_capacity(dim);
55 for _ in 0..dim {
56 v.push(self.sample_standard_normal());
57 }
58 v
59 }
60
61 #[inline]
65 pub fn dirichlet(&mut self, alpha: &[f64]) -> Vec64<f64> {
66 assert!(!alpha.is_empty(), "alpha must be non-empty");
67 assert!(
68 alpha.iter().all(|&a| a.is_finite() && a > 0.0),
69 "all alpha entries must be finite and > 0"
70 );
71
72 let mut draw = Vec64::with_capacity(alpha.len());
73 let mut sum = 0.0;
74 for &a in alpha {
75 let x = self.gamma(a, 1.0);
76 sum += x;
77 draw.push(x);
78 }
79 draw.iter_mut().for_each(|v| *v /= sum);
81 draw
82 }
83}
84
85#[inline]
88pub fn sample_standard_normal<R: Rng + ?Sized>(rng: &mut R) -> f64 {
89 let u1: f64 = rng.random::<f64>().max(f64::MIN_POSITIVE); let u2: f64 = rng.random::<f64>();
92 let r = (-2.0 * u1.ln()).sqrt();
93 r * (2.0 * PI * u2).cos()
94}
95
96#[inline]
99pub fn sample_gamma<R: Rng + ?Sized>(rng: &mut R, shape: f64, scale: f64) -> f64 {
100 assert!(
101 shape.is_finite() && shape > 0.0,
102 "shape must be finite and > 0"
103 );
104 assert!(
105 scale.is_finite() && scale > 0.0,
106 "scale must be finite and > 0"
107 );
108
109 if shape < 1.0 {
111 let u: f64 = rng.random::<f64>();
112 return sample_gamma(rng, shape + 1.0, scale) * u.powf(1.0 / shape);
113 }
114
115 let d = shape - 1.0 / 3.0;
116 let c = 1.0 / (9.0 * d).sqrt(); loop {
119 let x = sample_standard_normal(rng);
120 let one_plus_cx = 1.0 + c * x;
121 if one_plus_cx <= 0.0 {
122 continue;
123 }
124 let v = one_plus_cx * one_plus_cx * one_plus_cx; let u: f64 = rng.random::<f64>();
126
127 if u < 1.0 - 0.0331 * (x * x) * (x * x) {
129 return d * v * scale;
130 }
131 if u.ln() < 0.5 * x * x + d * (1.0 - v + v.ln()) {
133 return d * v * scale;
134 }
135 }
136}