1use rand::{Rng, SeedableRng};
6use rand_chacha::ChaCha20Rng;
7
8#[derive(Debug)]
15pub struct DeterministicRng {
16 rng: ChaCha20Rng,
17 seed: u64,
18 fork_counter: u64,
20}
21
22impl DeterministicRng {
23 #[must_use]
32 pub fn new(seed: u64) -> Self {
33 let rng = ChaCha20Rng::seed_from_u64(seed);
34
35 let result = Self {
37 rng,
38 seed,
39 fork_counter: 0,
40 };
41 assert_eq!(result.seed, seed, "seed must be stored");
42 result
43 }
44
45 #[must_use]
47 pub fn seed(&self) -> u64 {
48 self.seed
49 }
50
51 pub fn next_float(&mut self) -> f64 {
53 let value = self.rng.gen::<f64>();
54
55 assert!((0.0..1.0).contains(&value), "float must be in [0, 1)");
57 value
58 }
59
60 pub fn next_u64(&mut self) -> u64 {
62 self.rng.gen()
63 }
64
65 pub fn next_int(&mut self, min: i64, max: i64) -> i64 {
70 assert!(min <= max, "min ({}) must be <= max ({})", min, max);
72
73 let value = self.rng.gen_range(min..=max);
74
75 assert!(value >= min && value <= max, "value must be in range");
77 value
78 }
79
80 pub fn next_usize(&mut self, min: usize, max: usize) -> usize {
85 assert!(min <= max, "min ({}) must be <= max ({})", min, max);
87
88 let value = self.rng.gen_range(min..=max);
89
90 assert!(value >= min && value <= max, "value must be in range");
92 value
93 }
94
95 pub fn next_bool(&mut self, probability: f64) -> bool {
100 assert!(
102 (0.0..=1.0).contains(&probability),
103 "probability must be in [0, 1], got {}",
104 probability
105 );
106
107 self.next_float() < probability
108 }
109
110 pub fn choose<'a, T>(&mut self, items: &'a [T]) -> &'a T {
115 assert!(!items.is_empty(), "cannot choose from empty slice");
117
118 let index = self.next_usize(0, items.len() - 1);
119 &items[index]
120 }
121
122 pub fn shuffle<T>(&mut self, items: &mut [T]) {
124 for i in (1..items.len()).rev() {
126 let j = self.next_usize(0, i);
127 items.swap(i, j);
128 }
129 }
130
131 pub fn fork(&mut self) -> Self {
144 let fork_seed = self.seed.wrapping_add(
147 self.fork_counter
148 .wrapping_add(1)
149 .wrapping_mul(0x9E3779B97F4A7C15),
150 );
151 self.fork_counter += 1;
152
153 Self::new(fork_seed)
155 }
156
157 pub fn next_bytes(&mut self, len: usize) -> Vec<u8> {
159 assert!(len <= 1_000_000, "len must be <= 1MB");
161
162 let mut bytes = vec![0u8; len];
163 self.rng.fill(&mut bytes[..]);
164
165 assert_eq!(bytes.len(), len, "must generate requested bytes");
167 bytes
168 }
169}
170
171impl Clone for DeterministicRng {
172 fn clone(&self) -> Self {
173 Self {
174 rng: self.rng.clone(),
175 seed: self.seed,
176 fork_counter: self.fork_counter,
177 }
178 }
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_same_seed_same_sequence() {
        let mut rng1 = DeterministicRng::new(12345);
        let mut rng2 = DeterministicRng::new(12345);

        for _ in 0..100 {
            assert_eq!(rng1.next_float(), rng2.next_float());
        }
    }

    #[test]
    fn test_different_seeds_different_sequence() {
        let mut rng1 = DeterministicRng::new(12345);
        let mut rng2 = DeterministicRng::new(54321);

        let differs = (0..10).any(|_| rng1.next_float() != rng2.next_float());
        assert!(
            differs,
            "different seeds should produce different sequences"
        );
    }

    #[test]
    fn test_next_int_bounds() {
        let mut rng = DeterministicRng::new(42);

        for _ in 0..100 {
            let val = rng.next_int(5, 10);
            assert!((5..=10).contains(&val));
        }
    }

    #[test]
    fn test_next_usize_bounds() {
        let mut rng = DeterministicRng::new(42);

        for _ in 0..100 {
            let val = rng.next_usize(3, 7);
            assert!((3..=7).contains(&val));
        }

        // A degenerate range has exactly one possible value.
        assert_eq!(rng.next_usize(9, 9), 9);
        assert_eq!(rng.next_int(-4, -4), -4);
    }

    #[test]
    fn test_next_bool_always_false() {
        let mut rng = DeterministicRng::new(42);

        for _ in 0..100 {
            assert!(!rng.next_bool(0.0));
        }
    }

    #[test]
    fn test_next_bool_always_true() {
        let mut rng = DeterministicRng::new(42);

        for _ in 0..100 {
            assert!(rng.next_bool(1.0));
        }
    }

    #[test]
    fn test_fork_independence() {
        let mut rng = DeterministicRng::new(42);

        let mut fork1 = rng.fork();
        let mut fork2 = rng.fork();

        assert_ne!(
            fork1.seed(),
            fork2.seed(),
            "forks should have different seeds"
        );

        let fork1_vals: Vec<f64> = (0..5).map(|_| fork1.next_float()).collect();
        let fork2_vals: Vec<f64> = (0..5).map(|_| fork2.next_float()).collect();

        assert_ne!(
            fork1_vals, fork2_vals,
            "forks should have different sequences"
        );

        // The parent remains usable after forking.
        let _ = rng.next_float();
    }

    #[test]
    fn test_fork_is_reproducible() {
        let mut parent_a = DeterministicRng::new(7);
        let mut parent_b = DeterministicRng::new(7);

        for _ in 0..3 {
            let mut child_a = parent_a.fork();
            let mut child_b = parent_b.fork();
            assert_eq!(child_a.seed(), child_b.seed());
            assert_eq!(child_a.next_u64(), child_b.next_u64());
        }
    }

    #[test]
    fn test_fork_leaves_parent_stream_untouched() {
        let mut forked = DeterministicRng::new(99);
        let mut untouched = DeterministicRng::new(99);

        let _child = forked.fork();

        for _ in 0..10 {
            assert_eq!(forked.next_u64(), untouched.next_u64());
        }
    }

    #[test]
    fn test_clone_continues_same_stream() {
        let mut original = DeterministicRng::new(2024);
        for _ in 0..5 {
            original.next_u64();
        }

        let mut copy = original.clone();

        assert_eq!(copy.seed(), original.seed());
        for _ in 0..10 {
            assert_eq!(copy.next_u64(), original.next_u64());
        }
        assert_eq!(copy.fork().seed(), original.fork().seed());
    }

    #[test]
    fn test_choose() {
        let mut rng = DeterministicRng::new(42);
        let items = [1, 2, 3, 4, 5];

        for _ in 0..100 {
            let chosen = rng.choose(&items);
            assert!(items.contains(chosen));
        }
    }

    #[test]
    fn test_shuffle() {
        let mut rng = DeterministicRng::new(42);
        let mut items = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let original = items;

        rng.shuffle(&mut items);

        assert_ne!(items, original, "shuffle should change order");
        items.sort_unstable();
        assert_eq!(items, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
    }

    #[test]
    fn test_shuffle_short_slices() {
        let mut rng = DeterministicRng::new(42);

        let mut empty: [u8; 0] = [];
        rng.shuffle(&mut empty);

        let mut single = [7];
        rng.shuffle(&mut single);
        assert_eq!(single, [7]);
    }

    #[test]
    fn test_next_bytes() {
        let mut rng = DeterministicRng::new(42);
        let bytes = rng.next_bytes(32);
        assert_eq!(bytes.len(), 32);
        assert!(rng.next_bytes(0).is_empty());
    }

    #[test]
    fn test_next_bytes_deterministic() {
        let mut rng1 = DeterministicRng::new(5);
        let mut rng2 = DeterministicRng::new(5);
        assert_eq!(rng1.next_bytes(64), rng2.next_bytes(64));
    }

    #[test]
    #[should_panic(expected = "len must be <= 1MB")]
    fn test_next_bytes_too_long() {
        let mut rng = DeterministicRng::new(42);
        let _ = rng.next_bytes(1_000_001);
    }

    #[test]
    #[should_panic(expected = "min (10) must be <= max (5)")]
    fn test_next_int_invalid_range() {
        let mut rng = DeterministicRng::new(42);
        rng.next_int(10, 5);
    }

    #[test]
    #[should_panic(expected = "probability must be in [0, 1]")]
    fn test_next_bool_invalid_probability() {
        let mut rng = DeterministicRng::new(42);
        rng.next_bool(1.5);
    }

    #[test]
    #[should_panic(expected = "cannot choose from empty slice")]
    fn test_choose_empty() {
        let mut rng = DeterministicRng::new(42);
        let items: [i32; 0] = [];
        rng.choose(&items);
    }
}