Skip to main content

salib_core/
rng.rs

1//! `RngState` — multi-stream `ChaCha20` with deterministic salt-derived
2//! forking. The replay-determinism foundation for every salib sampler
3//! and estimator.
4//!
5//! # The state
6//!
7//! Four fields:
8//! - `algorithm`: closed enum, `ChaCha20` only today (future
9//!   `Pcg64` / `Xoshiro256pp` behind a `fast-rng` feature).
10//! - `seed: [u8; 32]`: 256-bit seed.
11//! - `stream: u64`: 2⁶⁴ independent streams per seed (`ChaCha20Rng::set_stream`).
12//! - `word_pos: u128`: position within the stream
13//!   (`ChaCha20Rng::get_word_pos` / `set_word_pos`); enables mid-flight
14//!   snapshot + resumption.
15//!
16//! Recording all four lets a verifier reconstruct any SA campaign's
17//! RNG stream from scratch.
18//!
19//! # Forking
20//!
21//! [`RngState::fork`] derives a child stream from a salt:
22//!
23//! ```text
//! child.stream   = parent.stream XOR u64::from_le_bytes(SHA-256(parent.stream.to_le_bytes() || salt)[..8])
25//! child.word_pos = 0
26//! child.seed     = parent.seed
27//! ```
28//!
//! The child is a pure function of the parent's `(seed, stream)` and the
//! salt (the parent's `word_pos` is ignored), so `parent.fork(b"block-7")`
//! always yields the same child regardless of process, machine, rayon
//! thread count, or wall-clock time. This is what makes parallel sampling
//! deterministic: rayon workers fork by block index, the resulting
//! per-block streams are stable across runs, and the sample matrix is
//! bit-identical regardless of how work was distributed.
36//!
//! Why mix in the parent's stream: the child's stream depends on *both*
//! `parent.stream` and `salt`, not just `salt` — the parent stream is
//! hashed together with the salt and also XORed into the digest prefix.
//! Two distinct parents that fork with the same salt therefore produce
//! distinct children (barring a 64-bit collision, probability on the order
//! of 2⁻⁶⁴) — important for nested forking patterns where the same salt
//! vocabulary recurs at multiple levels.
42//!
//! Why SHA-256 and not Blake3: for inputs this small (8 bytes of stream
//! plus a typically short salt, usually well under one 64-byte block) the
//! speed difference between SHA-256 and Blake3 is negligible. Blake3
//! lands in a later PR when there is a tree-mode-parallel-hashing
//! use case (e.g. content-addressing a 10⁶-row sample matrix).
47
48use rand_chacha::rand_core::SeedableRng;
49use rand_chacha::ChaCha20Rng;
50use serde::{Deserialize, Serialize};
51use sha2::{Digest, Sha256};
52
53/// The set of RNG algorithms salib supports. Closed enum +
54/// `#[non_exhaustive]`; future variants land non-breaking.
55#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
56#[non_exhaustive]
57pub enum RngAlgorithm {
58    /// `ChaCha20` (RFC 7539). 256-bit seed, 2⁶⁴ streams per seed,
59    /// 2¹²⁸ word-positions per stream. The default.
60    #[default]
61    ChaCha20,
62}
63
64/// The serializable RNG state. The single source of truth for any
65/// SA campaign's randomness.
66#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
67#[non_exhaustive]
68pub struct RngState {
69    pub algorithm: RngAlgorithm,
70    pub seed: [u8; 32],
71    pub stream: u64,
72    pub word_pos: u128,
73}
74
75impl RngState {
76    /// Construct a fresh state from a 32-byte seed. `ChaCha20`, stream
77    /// 0, `word_pos` 0.
78    #[must_use]
79    pub fn from_seed(seed: [u8; 32]) -> Self {
80        Self {
81            algorithm: RngAlgorithm::ChaCha20,
82            seed,
83            stream: 0,
84            word_pos: 0,
85        }
86    }
87
88    /// Construct from explicit fields. Only path into a non-default
89    /// `(stream, word_pos)` outside of `fork` and `snapshot`.
90    #[must_use]
91    pub fn from_parts(seed: [u8; 32], stream: u64, word_pos: u128) -> Self {
92        Self {
93            algorithm: RngAlgorithm::ChaCha20,
94            seed,
95            stream,
96            word_pos,
97        }
98    }
99
100    /// Derive a child state from a salt. The child shares the
101    /// parent's seed; the child stream is
102    /// `parent.stream XOR u64::from_le_bytes(SHA-256(parent.stream || salt)[..8])`.
103    /// The child's `word_pos` is reset to 0.
104    ///
105    /// Pure function of `(parent.stream, parent.seed, salt)`. Same
106    /// inputs always produce equal outputs.
107    #[must_use]
108    pub fn fork(&self, salt: &[u8]) -> Self {
109        let mut hasher = Sha256::new();
110        hasher.update(self.stream.to_le_bytes());
111        hasher.update(salt);
112        let digest = hasher.finalize();
113        let mut buf = [0u8; 8];
114        buf.copy_from_slice(&digest[..8]);
115        let mix = u64::from_le_bytes(buf);
116        Self {
117            algorithm: self.algorithm,
118            seed: self.seed,
119            stream: self.stream ^ mix,
120            word_pos: 0,
121        }
122    }
123
124    /// Construct a `ChaCha20Rng` initialized to this state's
125    /// `(seed, stream, word_pos)`. The handed-out RNG is detached
126    /// from this `RngState` — mutations to the returned RNG do not
127    /// update `self`.
128    #[must_use]
129    pub fn into_chacha(self) -> ChaCha20Rng {
130        let mut rng = ChaCha20Rng::from_seed(self.seed);
131        rng.set_stream(self.stream);
132        rng.set_word_pos(self.word_pos);
133        rng
134    }
135
136    /// Snapshot a `ChaCha20Rng`'s current `(stream, word_pos)` back
137    /// into a serializable `RngState`, preserving `algorithm` and
138    /// `seed` from `parent`. Used to record mid-flight RNG state
139    /// for resumption / audit.
140    ///
141    /// `parent` carries the `seed` because `ChaCha20Rng` does not
142    /// expose its seed once constructed; the caller is responsible
143    /// for handing a `parent` whose seed matches the RNG's actual
144    /// seed (typically the same `RngState` that produced the RNG via
145    /// `into_chacha`).
146    #[must_use]
147    pub fn snapshot(rng: &ChaCha20Rng, parent: &RngState) -> Self {
148        Self {
149            algorithm: parent.algorithm,
150            seed: parent.seed,
151            stream: rng.get_stream(),
152            word_pos: rng.get_word_pos(),
153        }
154    }
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use rand_chacha::rand_core::RngCore;

    fn seed_bytes(b: u8) -> [u8; 32] {
        [b; 32]
    }

    /// Draw `n` bytes from `rng`, advancing it.
    fn draw_n(rng: &mut ChaCha20Rng, n: usize) -> Vec<u8> {
        let mut buf = vec![0u8; n];
        rng.fill_bytes(&mut buf);
        buf
    }

    #[test]
    fn from_seed_defaults_stream_and_word_pos_to_zero() {
        let s = RngState::from_seed(seed_bytes(0x42));
        assert_eq!(s.stream, 0);
        assert_eq!(s.word_pos, 0);
        assert_eq!(s.algorithm, RngAlgorithm::ChaCha20);
        assert_eq!(s.seed, seed_bytes(0x42));
    }

    #[test]
    fn same_state_produces_identical_bytes() {
        let s = RngState::from_parts(seed_bytes(0x42), 7, 0);
        let mut a = s.clone().into_chacha();
        let mut b = s.into_chacha();
        assert_eq!(draw_n(&mut a, 1024), draw_n(&mut b, 1024));
    }

    #[test]
    fn fork_with_same_salt_is_deterministic() {
        let parent = RngState::from_parts(seed_bytes(0x42), 0, 0);
        let c1 = parent.fork(b"block-0");
        let c2 = parent.fork(b"block-0");
        assert_eq!(c1, c2);
    }

    #[test]
    fn fork_with_distinct_salts_produces_distinct_streams() {
        let parent = RngState::from_parts(seed_bytes(0x42), 0, 0);
        let c1 = parent.fork(b"block-0");
        let c2 = parent.fork(b"block-1");
        assert_ne!(c1.stream, c2.stream);
        let mut r1 = c1.into_chacha();
        let mut r2 = c2.into_chacha();
        assert_ne!(draw_n(&mut r1, 1024), draw_n(&mut r2, 1024));
    }

    #[test]
    fn fork_resets_child_word_pos() {
        let parent = RngState::from_parts(seed_bytes(0x42), 100, 999);
        let child = parent.fork(b"any");
        assert_eq!(child.word_pos, 0);
        assert_eq!(child.seed, parent.seed);
    }

    #[test]
    fn fork_xors_with_mix_so_distinct_parents_produce_distinct_children_under_same_salt() {
        let p1 = RngState::from_parts(seed_bytes(0x42), 1, 0);
        let p2 = RngState::from_parts(seed_bytes(0x42), 2, 0);
        let c1 = p1.fork(b"same-salt");
        let c2 = p2.fork(b"same-salt");
        assert_ne!(c1.stream, c2.stream);
    }

    #[test]
    fn snapshot_round_trips_word_pos_and_stream() {
        // Non-zero stream so the stream assertion below is not vacuous.
        let s = RngState::from_parts(seed_bytes(0x42), 3, 0);
        let mut r = s.clone().into_chacha();
        let _ = draw_n(&mut r, 8192);
        let snap = RngState::snapshot(&r, &s);
        // The snapshot carries the parent's seed and stream plus the
        // post-draw word_pos.
        assert_eq!(snap.seed, s.seed);
        assert_eq!(snap.stream, s.stream);
        assert_eq!(snap.word_pos, r.get_word_pos());
        // Drawing 1024 more bytes from the original and from a fresh chacha
        // initialized from the snapshot must agree.
        let mut original_continued = r;
        let mut resumed = snap.into_chacha();
        assert_eq!(
            draw_n(&mut original_continued, 1024),
            draw_n(&mut resumed, 1024)
        );
    }

    #[test]
    fn rngstate_serde_round_trip() {
        let s = RngState::from_parts(seed_bytes(0x42), 12345, 67890);
        let json = serde_json::to_string(&s).expect("serialize");
        let back: RngState = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(back, s);
    }

    // ── RngAlgorithm sanity ───────────────────────────────────────────

    #[test]
    fn rng_algorithm_default_is_chacha20() {
        assert_eq!(RngAlgorithm::default(), RngAlgorithm::ChaCha20);
    }

    #[test]
    fn rng_algorithm_serde_round_trip() {
        let json = serde_json::to_string(&RngAlgorithm::ChaCha20).expect("serialize");
        let back: RngAlgorithm = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(back, RngAlgorithm::ChaCha20);
    }

    // ── RngState construction equivalences ────────────────────────────

    #[test]
    fn from_seed_equals_from_parts_with_zero_stream_and_zero_word_pos() {
        let a = RngState::from_seed(seed_bytes(0x99));
        let b = RngState::from_parts(seed_bytes(0x99), 0, 0);
        assert_eq!(a, b);
    }

    #[test]
    fn from_parts_preserves_all_four_fields() {
        let s = RngState::from_parts([0xab; 32], 9_999_999, 12_345_678_901_234_567_890_u128);
        assert_eq!(s.algorithm, RngAlgorithm::ChaCha20);
        assert_eq!(s.seed, [0xab; 32]);
        assert_eq!(s.stream, 9_999_999);
        assert_eq!(s.word_pos, 12_345_678_901_234_567_890_u128);
    }

    // ── into_chacha / get_stream + get_word_pos round-trips ─────────

    #[test]
    fn into_chacha_initializes_stream_and_word_pos_correctly() {
        let s = RngState::from_parts(seed_bytes(0x42), 1234, 5678);
        let rng = s.clone().into_chacha();
        assert_eq!(rng.get_stream(), s.stream);
        assert_eq!(rng.get_word_pos(), s.word_pos);
    }

    #[test]
    fn drawing_from_into_chacha_rng_leaves_source_state_unchanged() {
        let s = RngState::from_parts(seed_bytes(0x42), 0, 0);
        let mut r = s.clone().into_chacha();
        let _ = draw_n(&mut r, 4096);
        // The generator advanced...
        assert!(r.get_word_pos() > 0);
        // ...but the state it was built from is untouched: the RNG is
        // detached and never writes back.
        assert_eq!(s, RngState::from_parts(seed_bytes(0x42), 0, 0));
    }

    // ── Fork: deeper properties ───────────────────────────────────────

    #[test]
    fn fork_is_pure_under_parent_stream_and_seed() {
        // Two parents with identical (stream, seed) and any word_pos
        // value must produce the same fork, because fork ignores
        // parent.word_pos.
        let p1 = RngState::from_parts(seed_bytes(0x42), 7, 0);
        let p2 = RngState::from_parts(seed_bytes(0x42), 7, 999);
        assert_eq!(p1.fork(b"some-salt"), p2.fork(b"some-salt"));
    }

    #[test]
    fn fork_keys_off_parent_stream_not_just_seed() {
        // Two parents with the same seed but different streams must
        // produce different forks even with the same salt.
        let p1 = RngState::from_parts(seed_bytes(0x42), 0, 0);
        let p2 = RngState::from_parts(seed_bytes(0x42), 1, 0);
        assert_ne!(p1.fork(b"x"), p2.fork(b"x"));
    }

    #[test]
    fn fork_stream_is_seed_independent_but_child_inherits_seed() {
        // The fork hash input is `(parent.stream, salt)` only. Two parents
        // that share a stream but differ in seed therefore get the same
        // child *stream* value. Each child still inherits its own parent's
        // seed, so the bytes they generate differ.
        let p1 = RngState::from_parts(seed_bytes(0x42), 5, 0);
        let p2 = RngState::from_parts(seed_bytes(0xbb), 5, 0);
        let c1 = p1.fork(b"x");
        let c2 = p2.fork(b"x");
        assert_eq!(c1.seed, p1.seed);
        assert_eq!(c2.seed, p2.seed);
        assert_eq!(c1.stream, c2.stream);
        let mut r1 = c1.into_chacha();
        let mut r2 = c2.into_chacha();
        assert_ne!(draw_n(&mut r1, 1024), draw_n(&mut r2, 1024));
    }

    #[test]
    fn fork_with_empty_salt_is_deterministic() {
        let parent = RngState::from_parts(seed_bytes(0x42), 0, 0);
        let c1 = parent.fork(b"");
        let c2 = parent.fork(b"");
        assert_eq!(c1, c2);
    }

    #[test]
    fn fork_with_empty_salt_differs_from_fork_with_nonempty_salt() {
        let parent = RngState::from_parts(seed_bytes(0x42), 0, 0);
        let c_empty = parent.fork(b"");
        let c_one = parent.fork(b"x");
        assert_ne!(c_empty.stream, c_one.stream);
    }

    #[test]
    fn fork_is_not_commutative_under_double_fork() {
        // Grandchild via fork-then-fork: the order of salts matters.
        let parent = RngState::from_parts(seed_bytes(0x42), 0, 0);
        let g_ab = parent.fork(b"a").fork(b"b");
        let g_ba = parent.fork(b"b").fork(b"a");
        assert_ne!(g_ab.stream, g_ba.stream);
    }

    #[test]
    fn fork_chains_of_same_salt_yield_pairwise_distinct_streams() {
        // parent → fork(s) → fork(s).fork(s) → fork(s).fork(s).fork(s)
        // Each generation gets a new stream value (modulo astronomically
        // unlikely 64-bit collisions).
        let p = RngState::from_parts(seed_bytes(0x42), 0, 0);
        let g1 = p.fork(b"s");
        let g2 = g1.fork(b"s");
        let g3 = g2.fork(b"s");
        assert_ne!(p.stream, g1.stream);
        assert_ne!(g1.stream, g2.stream);
        assert_ne!(g2.stream, g3.stream);
        // And no transitive aliasing:
        assert_ne!(p.stream, g2.stream);
        assert_ne!(p.stream, g3.stream);
        assert_ne!(g1.stream, g3.stream);
    }

    #[test]
    fn fork_long_salt_is_handled() {
        // A salt much longer than the SHA-256 block size is absorbed normally.
        let parent = RngState::from_parts(seed_bytes(0x42), 0, 0);
        let long_salt = vec![0xab; 4096];
        let c1 = parent.fork(&long_salt);
        let c2 = parent.fork(&long_salt);
        assert_eq!(c1, c2);
    }

    #[test]
    fn fork_one_byte_salt_difference_changes_stream() {
        // SHA-256's avalanche property: flipping any salt bit flips about
        // half the digest bits, so the stream values must differ.
        let parent = RngState::from_parts(seed_bytes(0x42), 0, 0);
        let c1 = parent.fork(b"abcdefgh");
        let c2 = parent.fork(b"abcdefgi");
        assert_ne!(c1.stream, c2.stream);
    }

    // ── Snapshot semantics ────────────────────────────────────────────

    #[test]
    fn snapshot_at_word_pos_zero_equals_a_fresh_state() {
        let s = RngState::from_parts(seed_bytes(0x42), 0, 0);
        let r = s.clone().into_chacha();
        let snap = RngState::snapshot(&r, &s);
        assert_eq!(snap, s);
    }

    #[test]
    fn snapshot_records_post_draw_word_pos() {
        let s = RngState::from_parts(seed_bytes(0x42), 0, 0);
        let mut r = s.clone().into_chacha();
        let _ = draw_n(&mut r, 4096);
        let snap = RngState::snapshot(&r, &s);
        assert!(snap.word_pos > 0);
        assert_eq!(snap.word_pos, r.get_word_pos());
    }

    #[test]
    fn snapshot_can_be_round_tripped_through_serde() {
        let s = RngState::from_parts(seed_bytes(0x42), 0, 0);
        let mut r = s.clone().into_chacha();
        let _ = draw_n(&mut r, 2048);
        let snap = RngState::snapshot(&r, &s);
        let json = serde_json::to_string(&snap).expect("serialize");
        let back: RngState = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(back, snap);
    }

    #[test]
    fn snapshot_matches_set_word_pos_resumption() {
        // Snapshot at an arbitrary word_pos K, then resume. The resumed RNG
        // must produce the same bytes as continuing the original RNG.
        let s = RngState::from_parts(seed_bytes(0x42), 0, 0);
        let mut r = s.clone().into_chacha();
        let _ = draw_n(&mut r, 12_345);
        let snap = RngState::snapshot(&r, &s);
        let mut r_resumed = snap.into_chacha();
        // Compare the next 4096 bytes.
        let cont = draw_n(&mut r, 4096);
        let resumed = draw_n(&mut r_resumed, 4096);
        assert_eq!(cont, resumed);
    }

    // ── Cross-stream isolation ────────────────────────────────────────

    #[test]
    fn distinct_streams_give_independent_byte_sequences() {
        // Streams 0..4 with the same seed produce mutually distinct
        // first-1024-byte draws.
        let seed = seed_bytes(0x42);
        let mut draws: Vec<Vec<u8>> = Vec::new();
        for stream in 0..4u64 {
            let mut r = RngState::from_parts(seed, stream, 0).into_chacha();
            draws.push(draw_n(&mut r, 1024));
        }
        for i in 0..4 {
            for j in (i + 1)..4 {
                assert_ne!(draws[i], draws[j], "stream {i} == stream {j}");
            }
        }
    }

    #[test]
    fn streams_with_distinct_seeds_are_independent() {
        let mut r1 = RngState::from_parts(seed_bytes(0xaa), 0, 0).into_chacha();
        let mut r2 = RngState::from_parts(seed_bytes(0xbb), 0, 0).into_chacha();
        assert_ne!(draw_n(&mut r1, 1024), draw_n(&mut r2, 1024));
    }
}