1use rand::{Rng, SeedableRng};
4use rand_chacha::ChaCha8Rng;
5use rand_distr::{Distribution, Exp, Gamma, Poisson};
6
/// Seedable random number source backed by a [`ChaCha8Rng`].
///
/// Every sampling method on this type draws from the single wrapped
/// generator, so the sequence of values produced depends only on the seed
/// passed to `SimRng::new` and on the order in which methods are called.
#[derive(Debug)]
pub struct SimRng {
    /// Underlying ChaCha8 generator that all sampling methods draw from.
    rng: ChaCha8Rng,
}
12
13impl SimRng {
14 pub fn new(seed: u64) -> Self {
16 Self {
17 rng: ChaCha8Rng::seed_from_u64(seed),
18 }
19 }
20
21 pub fn gen_range<T, R>(&mut self, range: R) -> T
23 where
24 T: rand::distributions::uniform::SampleUniform,
25 R: rand::distributions::uniform::SampleRange<T>,
26 {
27 self.rng.gen_range(range)
28 }
29
30 pub fn gen_f64(&mut self) -> f64 {
32 self.rng.gen()
33 }
34
35 pub fn gen_bool(&mut self, p: f64) -> bool {
37 self.rng.gen_bool(p)
38 }
39
40 pub fn exponential(&mut self, rate: f64) -> f64 {
45 if rate <= 0.0 {
46 return f64::INFINITY;
47 }
48 let exp = Exp::new(rate).unwrap();
49 exp.sample(&mut self.rng)
50 }
51
52 pub fn poisson(&mut self, lambda: f64) -> u32 {
56 if lambda <= 0.0 {
57 return 0;
58 }
59 let pois = Poisson::new(lambda).unwrap();
60 pois.sample(&mut self.rng) as u32
61 }
62
63 pub fn negbin(&mut self, mean: f64, dispersion: f64) -> u32 {
68 if mean <= 0.0 || dispersion <= 0.0 {
69 return 1;
70 }
71
72 let r = dispersion;
75 let p = r / (r + mean);
76 let gamma_shape = r;
77 let gamma_scale = (1.0 - p) / p;
78
79 let gamma = Gamma::new(gamma_shape, gamma_scale).unwrap();
80 let lambda = gamma.sample(&mut self.rng);
81
82 if lambda <= 0.0 {
83 return 1;
84 }
85
86 let pois = Poisson::new(lambda).unwrap();
87 let value = pois.sample(&mut self.rng) as u32;
88
89 value.max(1)
91 }
92
93 pub fn zipf(&mut self, n: usize, alpha: f64) -> usize {
97 if n == 0 {
98 return 0;
99 }
100 if n == 1 {
101 return 0;
102 }
103
104 let mut h_sum = 0.0;
106 for k in 1..=n {
107 h_sum += 1.0 / (k as f64).powf(alpha);
108 }
109
110 let u: f64 = self.rng.gen();
112 let target = u * h_sum;
113
114 let mut cumsum = 0.0;
115 for k in 1..=n {
116 cumsum += 1.0 / (k as f64).powf(alpha);
117 if cumsum >= target {
118 return k - 1; }
120 }
121
122 n - 1
123 }
124
125 pub fn normal(&mut self, mean: f64, stddev: f64) -> f64 {
127 use rand_distr::Normal;
128 let normal = Normal::new(mean, stddev).unwrap();
129 normal.sample(&mut self.rng)
130 }
131
132 pub fn lognormal(&mut self, mean: f64, stddev: f64) -> f64 {
134 use rand_distr::LogNormal;
135 let variance = stddev * stddev;
137 let mu = (mean * mean / (mean * mean + variance).sqrt()).ln();
138 let sigma = (1.0 + variance / (mean * mean)).ln().sqrt();
139 let lognormal = LogNormal::new(mu, sigma).unwrap();
140 lognormal.sample(&mut self.rng)
141 }
142
143 pub fn uniform(&mut self, min: f64, max: f64) -> f64 {
145 self.rng.gen_range(min..max)
146 }
147
148 pub fn choose<'a, T>(&mut self, slice: &'a [T]) -> Option<&'a T> {
150 if slice.is_empty() {
151 None
152 } else {
153 let idx = self.gen_range(0..slice.len());
154 Some(&slice[idx])
155 }
156 }
157
158 pub fn shuffle<T>(&mut self, slice: &mut [T]) {
160 use rand::seq::SliceRandom;
161 slice.shuffle(&mut self.rng);
162 }
163}
164
#[cfg(test)]
mod tests {
    use super::*;

    /// Two generators built from the same seed must emit identical streams.
    #[test]
    fn test_determinism() {
        let mut first = SimRng::new(42);
        let mut second = SimRng::new(42);

        for _ in 0..100 {
            assert_eq!(first.gen_f64(), second.gen_f64());
        }
    }

    /// The sample mean of exponential draws should be close to `1 / rate`.
    #[test]
    fn test_exponential() {
        let rate = 2.0;
        let draws = 10000;
        let mut rng = SimRng::new(42);

        // Summed in draw order, exactly as collecting first and summing would.
        let total: f64 = (0..draws).map(|_| rng.exponential(rate)).sum();
        let sample_mean = total / draws as f64;
        let theoretical_mean = 1.0 / rate;

        assert!((sample_mean - theoretical_mean).abs() < 0.1);
    }

    /// Negative-binomial draws are clamped to be at least one.
    #[test]
    fn test_negbin() {
        let mut rng = SimRng::new(42);

        assert!((0..100).all(|_| rng.negbin(2.2, 1.3) >= 1));
    }

    /// Zipf draws stay in range and favour the first rank over the last.
    #[test]
    fn test_zipf() {
        let n = 100;
        let mut rng = SimRng::new(42);
        let mut histogram = vec![0u32; n];

        for _ in 0..10000 {
            let rank = rng.zipf(n, 1.0);
            assert!(rank < n);
            histogram[rank] += 1;
        }

        let most_popular = histogram[0];
        let least_popular = histogram[n - 1];
        assert!(most_popular > least_popular);
    }
}