1use rand_chacha::rand_core::SeedableRng;
49use rand_chacha::ChaCha20Rng;
50use serde::{Deserialize, Serialize};
51use sha2::{Digest, Sha256};
52
53#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
56#[non_exhaustive]
57pub enum RngAlgorithm {
58 #[default]
61 ChaCha20,
62}
63
64#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
67#[non_exhaustive]
68pub struct RngState {
69 pub algorithm: RngAlgorithm,
70 pub seed: [u8; 32],
71 pub stream: u64,
72 pub word_pos: u128,
73}
74
75impl RngState {
76 #[must_use]
79 pub fn from_seed(seed: [u8; 32]) -> Self {
80 Self {
81 algorithm: RngAlgorithm::ChaCha20,
82 seed,
83 stream: 0,
84 word_pos: 0,
85 }
86 }
87
88 #[must_use]
91 pub fn from_parts(seed: [u8; 32], stream: u64, word_pos: u128) -> Self {
92 Self {
93 algorithm: RngAlgorithm::ChaCha20,
94 seed,
95 stream,
96 word_pos,
97 }
98 }
99
100 #[must_use]
108 pub fn fork(&self, salt: &[u8]) -> Self {
109 let mut hasher = Sha256::new();
110 hasher.update(self.stream.to_le_bytes());
111 hasher.update(salt);
112 let digest = hasher.finalize();
113 let mut buf = [0u8; 8];
114 buf.copy_from_slice(&digest[..8]);
115 let mix = u64::from_le_bytes(buf);
116 Self {
117 algorithm: self.algorithm,
118 seed: self.seed,
119 stream: self.stream ^ mix,
120 word_pos: 0,
121 }
122 }
123
124 #[must_use]
129 pub fn into_chacha(self) -> ChaCha20Rng {
130 let mut rng = ChaCha20Rng::from_seed(self.seed);
131 rng.set_stream(self.stream);
132 rng.set_word_pos(self.word_pos);
133 rng
134 }
135
136 #[must_use]
147 pub fn snapshot(rng: &ChaCha20Rng, parent: &RngState) -> Self {
148 Self {
149 algorithm: parent.algorithm,
150 seed: parent.seed,
151 stream: rng.get_stream(),
152 word_pos: rng.get_word_pos(),
153 }
154 }
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use rand_chacha::rand_core::RngCore;

    /// Returns a seed whose 32 bytes all equal `b`.
    fn seed_bytes(b: u8) -> [u8; 32] {
        [b; 32]
    }

    /// Draws `n` bytes from `rng`, advancing it.
    fn draw_n(rng: &mut ChaCha20Rng, n: usize) -> Vec<u8> {
        let mut buf = vec![0u8; n];
        rng.fill_bytes(&mut buf);
        buf
    }

    /// `from_seed` stores the seed and zeroes stream and word position.
    #[test]
    fn from_seed_defaults_stream_and_word_pos_to_zero() {
        let s = RngState::from_seed(seed_bytes(0x42));
        assert_eq!(s.stream, 0);
        assert_eq!(s.word_pos, 0);
        assert_eq!(s.algorithm, RngAlgorithm::ChaCha20);
        assert_eq!(s.seed, seed_bytes(0x42));
    }

    /// Two generators built from equal states emit the same bytes.
    #[test]
    fn same_state_produces_identical_bytes() {
        let s = RngState::from_parts(seed_bytes(0x42), 7, 0);
        let mut a = s.clone().into_chacha();
        let mut b = s.into_chacha();
        assert_eq!(draw_n(&mut a, 1024), draw_n(&mut b, 1024));
    }

    /// Forking twice with the same salt gives equal children.
    #[test]
    fn fork_with_same_salt_is_deterministic() {
        let parent = RngState::from_parts(seed_bytes(0x42), 0, 0);
        let c1 = parent.fork(b"block-0");
        let c2 = parent.fork(b"block-0");
        assert_eq!(c1, c2);
    }

    /// Different salts give different stream numbers and different output.
    #[test]
    fn fork_with_distinct_salts_produces_distinct_streams() {
        let parent = RngState::from_parts(seed_bytes(0x42), 0, 0);
        let c1 = parent.fork(b"block-0");
        let c2 = parent.fork(b"block-1");
        assert_ne!(c1.stream, c2.stream);
        let mut r1 = c1.into_chacha();
        let mut r2 = c2.into_chacha();
        assert_ne!(draw_n(&mut r1, 1024), draw_n(&mut r2, 1024));
    }

    /// A child starts at word position 0 and keeps the parent's seed.
    #[test]
    fn fork_resets_child_word_pos() {
        let parent = RngState::from_parts(seed_bytes(0x42), 100, 999);
        let child = parent.fork(b"any");
        assert_eq!(child.word_pos, 0);
        assert_eq!(child.seed, parent.seed);
    }

    /// Parents on different streams get different child streams for one salt.
    #[test]
    fn fork_xors_with_mix_so_distinct_parents_produce_distinct_children_under_same_salt() {
        let p1 = RngState::from_parts(seed_bytes(0x42), 1, 0);
        let p2 = RngState::from_parts(seed_bytes(0x42), 2, 0);
        let c1 = p1.fork(b"same-salt");
        let c2 = p2.fork(b"same-salt");
        assert_ne!(c1.stream, c2.stream);
    }

    /// A snapshot records the live word position and stream, and a generator
    /// rebuilt from it continues exactly where the original does.
    #[test]
    fn snapshot_round_trips_word_pos_and_stream() {
        let s = RngState::from_parts(seed_bytes(0x42), 0, 0);
        let mut r = s.clone().into_chacha();
        let _ = draw_n(&mut r, 8192);
        let snap = RngState::snapshot(&r, &s);
        assert_eq!(snap.word_pos, r.get_word_pos());
        // The test name covers the stream as well, so check it explicitly.
        assert_eq!(snap.stream, r.get_stream());
        let mut original_continued = r;
        let mut resumed = snap.into_chacha();
        assert_eq!(
            draw_n(&mut original_continued, 1024),
            draw_n(&mut resumed, 1024)
        );
    }

    /// `RngState` survives a JSON round trip unchanged.
    #[test]
    fn rngstate_serde_round_trip() {
        let s = RngState::from_parts(seed_bytes(0x42), 12345, 67890);
        let json = serde_json::to_string(&s).expect("serialize");
        let back: RngState = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(back, s);
    }

    /// The default algorithm is ChaCha20.
    #[test]
    fn rng_algorithm_default_is_chacha20() {
        assert_eq!(RngAlgorithm::default(), RngAlgorithm::ChaCha20);
    }

    /// `RngAlgorithm` survives a JSON round trip unchanged.
    #[test]
    fn rng_algorithm_serde_round_trip() {
        let json = serde_json::to_string(&RngAlgorithm::ChaCha20).expect("serialize");
        let back: RngAlgorithm = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(back, RngAlgorithm::ChaCha20);
    }

    /// `from_seed(seed)` equals `from_parts(seed, 0, 0)`.
    #[test]
    fn from_seed_equals_from_parts_with_zero_stream_and_zero_word_pos() {
        let a = RngState::from_seed(seed_bytes(0x99));
        let b = RngState::from_parts(seed_bytes(0x99), 0, 0);
        assert_eq!(a, b);
    }

    /// `from_parts` stores every argument verbatim and selects ChaCha20.
    #[test]
    fn from_parts_preserves_all_four_fields() {
        let s = RngState::from_parts([0xab; 32], 9_999_999, 12_345_678_901_234_567_890_u128);
        assert_eq!(s.algorithm, RngAlgorithm::ChaCha20);
        assert_eq!(s.seed, [0xab; 32]);
        assert_eq!(s.stream, 9_999_999);
        assert_eq!(s.word_pos, 12_345_678_901_234_567_890_u128);
    }

    /// The generator reports the stream and word position of its state.
    #[test]
    fn into_chacha_initializes_stream_and_word_pos_correctly() {
        let s = RngState::from_parts(seed_bytes(0x42), 1234, 5678);
        let rng = s.clone().into_chacha();
        assert_eq!(rng.get_stream(), s.stream);
        assert_eq!(rng.get_word_pos(), s.word_pos);
    }

    /// Drawing from a generator built from a clone leaves the state untouched.
    #[test]
    fn into_chacha_clones_so_state_is_unchanged_by_draws() {
        let s = RngState::from_parts(seed_bytes(0x42), 0, 0);
        let mut r = s.clone().into_chacha();
        let _ = draw_n(&mut r, 4096);
        assert_eq!(s.word_pos, 0);
    }

    /// The parent's word position does not influence the child.
    #[test]
    fn fork_is_pure_under_parent_stream_and_seed() {
        let p1 = RngState::from_parts(seed_bytes(0x42), 7, 0);
        let p2 = RngState::from_parts(seed_bytes(0x42), 7, 999);
        assert_eq!(p1.fork(b"some-salt"), p2.fork(b"some-salt"));
    }

    /// Same seed and salt but different parent streams give different children.
    #[test]
    fn fork_keys_off_parent_stream_not_just_seed() {
        let p1 = RngState::from_parts(seed_bytes(0x42), 0, 0);
        let p2 = RngState::from_parts(seed_bytes(0x42), 1, 0);
        assert_ne!(p1.fork(b"x"), p2.fork(b"x"));
    }

    /// Parents that differ only in seed get the same child stream number, but
    /// the children still emit different bytes because each keeps its own seed.
    #[test]
    fn fork_keys_off_seed_not_just_stream() {
        let p1 = RngState::from_parts(seed_bytes(0x42), 5, 0);
        let p2 = RngState::from_parts(seed_bytes(0xbb), 5, 0);
        let c1 = p1.fork(b"x");
        let c2 = p2.fork(b"x");
        assert_eq!(c1.seed, p1.seed);
        assert_eq!(c2.seed, p2.seed);
        assert_eq!(c1.stream, c2.stream);
        let mut r1 = c1.into_chacha();
        let mut r2 = c2.into_chacha();
        assert_ne!(draw_n(&mut r1, 1024), draw_n(&mut r2, 1024));
    }

    /// An empty salt is accepted and forks deterministically.
    #[test]
    fn fork_with_empty_salt_is_deterministic() {
        let parent = RngState::from_parts(seed_bytes(0x42), 0, 0);
        let c1 = parent.fork(b"");
        let c2 = parent.fork(b"");
        assert_eq!(c1, c2);
    }

    /// The empty salt and a one-byte salt lead to different streams.
    #[test]
    fn fork_with_empty_salt_differs_from_fork_with_nonempty_salt() {
        let parent = RngState::from_parts(seed_bytes(0x42), 0, 0);
        let c_empty = parent.fork(b"");
        let c_one = parent.fork(b"x");
        assert_ne!(c_empty.stream, c_one.stream);
    }

    /// Forking by "a" then "b" differs from forking by "b" then "a".
    #[test]
    fn fork_is_not_commutative_under_double_fork() {
        let parent = RngState::from_parts(seed_bytes(0x42), 0, 0);
        let g_ab = parent.fork(b"a").fork(b"b");
        let g_ba = parent.fork(b"b").fork(b"a");
        assert_ne!(g_ab.stream, g_ba.stream);
    }

    /// Repeatedly forking with one salt yields pairwise distinct streams
    /// across the parent and three generations.
    #[test]
    fn fork_chains_of_same_salt_strictly_descend_to_distinct_streams() {
        let p = RngState::from_parts(seed_bytes(0x42), 0, 0);
        let g1 = p.fork(b"s");
        let g2 = g1.fork(b"s");
        let g3 = g2.fork(b"s");
        assert_ne!(p.stream, g1.stream);
        assert_ne!(g1.stream, g2.stream);
        assert_ne!(g2.stream, g3.stream);
        assert_ne!(p.stream, g2.stream);
        assert_ne!(p.stream, g3.stream);
        assert_ne!(g1.stream, g3.stream);
    }

    /// A 4096-byte salt is accepted and forks deterministically.
    #[test]
    fn fork_long_salt_is_handled() {
        let parent = RngState::from_parts(seed_bytes(0x42), 0, 0);
        let long_salt = vec![0xab_u8; 4096];
        let c1 = parent.fork(&long_salt);
        let c2 = parent.fork(&long_salt);
        assert_eq!(c1, c2);
    }

    /// Changing only the last salt byte changes the child stream.
    #[test]
    fn fork_one_byte_salt_difference_changes_stream() {
        let parent = RngState::from_parts(seed_bytes(0x42), 0, 0);
        let c1 = parent.fork(b"abcdefgh");
        let c2 = parent.fork(b"abcdefgi");
        assert_ne!(c1.stream, c2.stream);
    }

    /// Snapshotting an untouched generator reproduces the state it came from.
    #[test]
    fn snapshot_at_word_pos_zero_equals_a_fresh_state() {
        let s = RngState::from_parts(seed_bytes(0x42), 0, 0);
        let r = s.clone().into_chacha();
        let snap = RngState::snapshot(&r, &s);
        assert_eq!(snap, s);
    }

    /// After drawing, the snapshot's word position is non-zero and matches
    /// the generator's.
    #[test]
    fn snapshot_records_post_draw_word_pos() {
        let s = RngState::from_parts(seed_bytes(0x42), 0, 0);
        let mut r = s.clone().into_chacha();
        let _ = draw_n(&mut r, 4096);
        let snap = RngState::snapshot(&r, &s);
        assert!(snap.word_pos > 0);
        assert_eq!(snap.word_pos, r.get_word_pos());
    }

    /// A mid-stream snapshot survives a JSON round trip unchanged.
    #[test]
    fn snapshot_can_be_round_tripped_through_serde() {
        let s = RngState::from_parts(seed_bytes(0x42), 0, 0);
        let mut r = s.clone().into_chacha();
        let _ = draw_n(&mut r, 2048);
        let snap = RngState::snapshot(&r, &s);
        let json = serde_json::to_string(&snap).expect("serialize");
        let back: RngState = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(back, snap);
    }

    /// Resuming from a snapshot taken after a draw that is not a multiple of
    /// four bytes still matches the original generator's continuation.
    #[test]
    fn snapshot_matches_set_word_pos_resumption() {
        let s = RngState::from_parts(seed_bytes(0x42), 0, 0);
        let mut r = s.clone().into_chacha();
        let _ = draw_n(&mut r, 12_345);
        let snap = RngState::snapshot(&r, &s);
        let mut r_resumed = snap.into_chacha();
        let cont = draw_n(&mut r, 4096);
        let resumed = draw_n(&mut r_resumed, 4096);
        assert_eq!(cont, resumed);
    }

    /// Streams 0 through 3 under one seed produce pairwise different output.
    #[test]
    fn distinct_streams_give_independent_byte_sequences() {
        let seed = seed_bytes(0x42);
        let draws: Vec<Vec<u8>> = (0..4u64)
            .map(|stream| {
                let mut r = RngState::from_parts(seed, stream, 0).into_chacha();
                draw_n(&mut r, 1024)
            })
            .collect();
        // Compare every unordered pair exactly once.
        for (i, left) in draws.iter().enumerate() {
            for (j, right) in draws.iter().enumerate().skip(i + 1) {
                assert_ne!(left, right, "stream {i} == stream {j}");
            }
        }
    }

    /// Different seeds on the same stream produce different output.
    #[test]
    fn streams_with_distinct_seeds_are_independent() {
        let mut r1 = RngState::from_parts(seed_bytes(0xaa), 0, 0).into_chacha();
        let mut r2 = RngState::from_parts(seed_bytes(0xbb), 0, 0).into_chacha();
        assert_ne!(draw_n(&mut r1, 1024), draw_n(&mut r2, 1024));
    }
}