// typhoon/utils/random.rs

1#[cfg(test)]
2use rand::SeedableRng;
3use rand::rngs::OsRng;
4use rand::{CryptoRng, Rng, RngCore};
5
6use crate::bytes::FixedByteBuffer;
7
/// Extension methods on top of the standard `Rng` interface.
///
/// Implemented for [`TyphoonRng`], the RNG handed out by [`get_rng`].
///
/// Note: in the byte-producing methods the const generic `T` is a byte count,
/// not a type.
pub trait SupportRng {
    /// Returns an array of `T` random bytes.
    fn random_byte_array<const T: usize>(&mut self) -> [u8; T];

    /// Returns a [`FixedByteBuffer`] filled with `T` random bytes.
    fn random_byte_buffer<const T: usize>(&mut self) -> FixedByteBuffer<T>;

    /// Picks a random element of `slice`, or returns `None` when `slice` is empty.
    ///
    /// Only compiled with the `client` feature.
    #[cfg(feature = "client")]
    fn random_item<'a, T>(&mut self, slice: &'a [T]) -> Option<&'a T>;
}
17
// ── TyphoonRng ────────────────────────────────────────────────────────────────

/// Unified RNG wrapper used throughout the codebase.
///
/// In production this is always backed by `OsRng`. In test builds, once
/// [`set_test_rng_seed`] has been called on the current thread, [`get_rng`]
/// returns the `Seeded` variant instead: a deterministic `StdRng` forked from
/// per-thread state. This makes packet-construction randomness reproducible.
pub enum TyphoonRng {
    /// Operating-system entropy source; the only variant outside test builds.
    Os(OsRng),
    /// Deterministic generator; compiled only under `cfg(test)`.
    #[cfg(test)]
    Seeded(rand::rngs::StdRng),
}
30
31impl RngCore for TyphoonRng {
32    fn next_u32(&mut self) -> u32 {
33        match self {
34            TyphoonRng::Os(r) => r.next_u32(),
35            #[cfg(test)]
36            TyphoonRng::Seeded(r) => r.next_u32(),
37        }
38    }
39
40    fn next_u64(&mut self) -> u64 {
41        match self {
42            TyphoonRng::Os(r) => r.next_u64(),
43            #[cfg(test)]
44            TyphoonRng::Seeded(r) => r.next_u64(),
45        }
46    }
47
48    fn fill_bytes(&mut self, dest: &mut [u8]) {
49        match self {
50            TyphoonRng::Os(r) => r.fill_bytes(dest),
51            #[cfg(test)]
52            TyphoonRng::Seeded(r) => r.fill_bytes(dest),
53        }
54    }
55
56    fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
57        match self {
58            TyphoonRng::Os(r) => r.try_fill_bytes(dest),
59            #[cfg(test)]
60            TyphoonRng::Seeded(r) => r.try_fill_bytes(dest),
61        }
62    }
63}
64
/// Marks `TyphoonRng` as crypto-safe.
///
/// The production variant wraps `OsRng` directly. The `Seeded` variant exists
/// only under `cfg(test)`. `get_rng` only produces it after a test has
/// explicitly opted in to determinism via `set_test_rng_seed`. Its output is
/// reproducible by design, so it must never back real secrets.
impl CryptoRng for TyphoonRng {}
69
70impl SupportRng for TyphoonRng {
71    fn random_byte_array<const T: usize>(&mut self) -> [u8; T] {
72        let mut buf = [0u8; T];
73        self.fill_bytes(&mut buf);
74        buf
75    }
76
77    fn random_byte_buffer<const T: usize>(&mut self) -> FixedByteBuffer<T> {
78        FixedByteBuffer::from_array(self.random_byte_array::<T>())
79    }
80
81    #[cfg(feature = "client")]
82    fn random_item<'a, T>(&mut self, slice: &'a [T]) -> Option<&'a T> {
83        if slice.is_empty() {
84            None
85        } else {
86            Some(&slice[self.gen_range(0..slice.len())])
87        }
88    }
89}
90
// ── Test seed management ──────────────────────────────────────────────────────

#[cfg(test)]
use std::cell::RefCell;

// Per-thread deterministic RNG state.
//
// - `set_test_rng_seed` stores `Some(state)`; `clear_test_rng` resets it to `None`.
// - While it holds `Some`, each `get_rng` call draws one `u64` from it and seeds
//   a fresh `StdRng` with that value. Callers therefore advance independently,
//   but the overall sequence stays deterministic.
// - `None` (the initial value) makes `get_rng` fall back to `OsRng`.
#[cfg(test)]
thread_local! {
    static TEST_RNG: RefCell<Option<rand::rngs::StdRng>> = const { RefCell::new(None) };
}
103
/// Seeds the deterministic RNG for the current thread.
///
/// Afterwards, every [`get_rng`] call on this thread returns a `StdRng` forked
/// from this state instead of `OsRng`. This lasts until [`clear_test_rng`] runs
/// or a new seed replaces it. Use distinct seeds in distinct tests to get
/// independent sequences.
#[cfg(test)]
pub fn set_test_rng_seed(seed: u64) {
    let root = rand::rngs::StdRng::seed_from_u64(seed);
    TEST_RNG.with(|slot| {
        // Any previously installed state is discarded.
        let _previous = slot.replace(Some(root));
    });
}
113
/// Resets the current thread to `OsRng`, undoing [`set_test_rng_seed`].
#[cfg(test)]
pub fn clear_test_rng() {
    // Taking the state leaves `None` behind, which `get_rng` treats as
    // "use the OS entropy source".
    let _discarded = TEST_RNG.with(|slot| slot.take());
}
119
120// ── Factory ───────────────────────────────────────────────────────────────────
121
122/// Return a `TyphoonRng` for this call site.
123///
124/// In production: always `OsRng`.
125/// In tests: if [`set_test_rng_seed`] was called, forks a deterministic
126/// `StdRng` from the thread-local state (so each call gets an independent
127/// but reproducible sequence); otherwise falls back to `OsRng`.
128#[inline]
129pub fn get_rng() -> TyphoonRng {
130    #[cfg(test)]
131    {
132        let forked = TEST_RNG.with(|r| r.borrow_mut().as_mut().map(|rng| rand::rngs::StdRng::seed_from_u64(rng.next_u64())));
133        if let Some(seeded) = forked {
134            return TyphoonRng::Seeded(seeded);
135        }
136    }
137    TyphoonRng::Os(OsRng)
138}
139
// Unit tests for this module live in a separate file; the `#[path]` attribute
// points the `tests` module at it.
#[cfg(test)]
#[path = "../../tests/utils/random.rs"]
mod tests;
143
144/// Sample a chunk size around `chunk` with two-sided `jitter`, clamped to `[1, max_payload]`, `chunk == 0` is the sentinel for "saturate the MTU".
145#[inline]
146pub fn jittered_chunk_size(max_payload: usize, chunk: usize, jitter: f64) -> usize {
147    let target = if chunk == 0 {
148        max_payload
149    } else {
150        chunk
151    };
152    if max_payload <= 1 {
153        return max_payload;
154    }
155    let target_f = target as f64;
156    let delta = (target_f * jitter).round() as usize;
157    let lo = target.saturating_sub(delta).max(1);
158    let hi = target.saturating_add(delta).min(max_payload);
159    if lo >= hi {
160        return hi;
161    }
162    get_rng().gen_range(lo..=hi)
163}
164
/// Picks one of several branches at random, weighted by the per-branch weights, and
/// evaluates the chosen branch as an expression (its value is the value of the macro).
///
/// Each branch is either `weight => body` or just `body` (implied weight `1u32`).
/// Weights are converted with `as u32`, so they should be non-negative integer
/// expressions; other values are converted silently (e.g. negative integers wrap).
/// Bodies must all evaluate to the same type.
/// Branches are separated by commas; trailing block bodies may omit the comma.
/// At least one branch is required; an empty invocation is a compile error.
///
/// Only the chosen body is evaluated. Bodies are expanded into a plain `if`/`else if`
/// chain rather than a labeled block, so a body may use `break`/`continue` to control
/// a loop that encloses the invocation. Unlabeled `break`/`continue` inside a labeled
/// block is rejected by rustc (E0695). `return` and `?` also work as usual.
///
/// The RNG comes from `$crate::utils::random::get_rng()`, so selection is
/// deterministic in tests after `set_test_rng_seed`. NOTE(review): the expansion
/// names the `weighted_rand` crate by its bare path, so invoking crates presumably
/// need it as a direct dependency; confirm.
///
/// ```ignore
/// let size = weighted_random! {
///     3 => small(),
///     1 => large(),
/// };
/// ```
#[macro_export]
macro_rules! weighted_random {
    // ── Reject an empty invocation: no branch could ever be chosen ────────────
    (@parse {} -> () ()) => {
        ::core::compile_error!("`weighted_random!` requires at least one branch")
    };

    // ── Final step: emit the weighted dispatch ────────────────────────────────
    (@parse {} -> ($($weights:expr,)*) ($($bodies:expr,)*)) => {{
        use weighted_rand::builder::NewBuilder as _;
        let __weights: &[u32] = &[$( ($weights) as u32 ),*];
        let __table = weighted_rand::builder::WalkerTableBuilder::new(__weights).build();
        let mut __rng = $crate::utils::random::get_rng();
        let __idx = __table.next_rng(&mut __rng);
        // `__is_chosen` is called once per condition of the `if` chain below, in
        // branch order. It yields `true` only on the `__idx`-th call (0-based), so
        // branch `k` runs iff `__idx == k`.
        let mut __counter = 0usize;
        #[allow(unused_assignments)]
        let mut __is_chosen = || {
            let __hit = __counter == __idx;
            __counter += 1;
            __hit
        };
        $( if __is_chosen() { $bodies } else )*
        { ::core::unreachable!() }
    }};

    // ── Skip leading comma (allows trailing/leading commas naturally) ─────────
    (@parse {, $($rest:tt)*} -> ($($weights:expr,)*) ($($bodies:expr,)*)) => {
        $crate::weighted_random!(@parse {$($rest)*} -> ($($weights,)*) ($($bodies,)*))
    };

    // ── `weight => { block }` followed by more (no trailing comma needed) ─────
    (@parse {$weight:expr => $body:block $($rest:tt)*} -> ($($weights:expr,)*) ($($bodies:expr,)*)) => {
        $crate::weighted_random!(@parse {$($rest)*} -> ($($weights,)* $weight,) ($($bodies,)* $body,))
    };

    // ── Bare `{ block }` followed by more (no trailing comma needed) ──────────
    (@parse {$body:block $($rest:tt)*} -> ($($weights:expr,)*) ($($bodies:expr,)*)) => {
        $crate::weighted_random!(@parse {$($rest)*} -> ($($weights,)* 1u32,) ($($bodies,)* $body,))
    };

    // ── `weight => expr, ...` ─────────────────────────────────────────────────
    (@parse {$weight:expr => $body:expr, $($rest:tt)*} -> ($($weights:expr,)*) ($($bodies:expr,)*)) => {
        $crate::weighted_random!(@parse {$($rest)*} -> ($($weights,)* $weight,) ($($bodies,)* $body,))
    };
    // ── final `weight => expr` (no trailing comma) ────────────────────────────
    (@parse {$weight:expr => $body:expr} -> ($($weights:expr,)*) ($($bodies:expr,)*)) => {
        $crate::weighted_random!(@parse {} -> ($($weights,)* $weight,) ($($bodies,)* $body,))
    };

    // ── bare `expr, ...` ──────────────────────────────────────────────────────
    (@parse {$body:expr, $($rest:tt)*} -> ($($weights:expr,)*) ($($bodies:expr,)*)) => {
        $crate::weighted_random!(@parse {$($rest)*} -> ($($weights,)* 1u32,) ($($bodies,)* $body,))
    };
    // ── final bare `expr` (no trailing comma) ─────────────────────────────────
    (@parse {$body:expr} -> ($($weights:expr,)*) ($($bodies:expr,)*)) => {
        $crate::weighted_random!(@parse {} -> ($($weights,)* 1u32,) ($($bodies,)* $body,))
    };

    // ── Catch-all: malformed @parse input fails fast (prevents infinite recursion
    //    via the entry arm below).
    (@parse $($_rest:tt)*) => {
        ::core::compile_error!(
            "malformed `weighted_random!` input — expected comma-separated `weight => expr` or `expr` branches"
        )
    };

    // ── Entry point ───────────────────────────────────────────────────────────
    ($($input:tt)*) => {
        $crate::weighted_random!(@parse {$($input)*} -> () ())
    };
}