tensorlogic_train/augmentation/rng.rs
/// A simple deterministic Linear Congruential Generator (LCG) RNG.
///
/// This is intentionally lightweight and avoids pulling in the `rand` crate.
/// The multiplier and increment are the constants Knuth lists for MMIX. The
/// increment is odd and the multiplier is ≡ 1 (mod 4), so the state sequence
/// has the full period 2^64 (Hull–Dobell).
///
/// The same seed always produces the same sequence. Not suitable for
/// cryptographic use.
#[derive(Debug, Clone)]
pub struct AugRng {
    state: u64,
}

impl AugRng {
    /// Create a generator from `seed`.
    ///
    /// The seed is offset by `0x9e37_79b9_7f4a_7c15` (⌊2^64/φ⌋) so that small
    /// seeds such as 0 do not start from a near-zero state.
    ///
    /// NOTE(review): an additive offset does not decorrelate nearby seeds. The
    /// streams for seeds `s` and `s + 1` differ by a seed-independent,
    /// deterministic amount at every step. Consider a SplitMix64-style mix if
    /// many generators are seeded with consecutive integers.
    pub fn new(seed: u64) -> Self {
        let state = seed.wrapping_add(0x9e37_79b9_7f4a_7c15);
        Self { state }
    }

    /// Advance the LCG and return the new raw 64-bit state.
    ///
    /// The low-order bits of this value are weak (see `next_usize`). Callers
    /// should derive their outputs from the high bits.
    #[inline]
    fn next_u64(&mut self) -> u64 {
        // LCG: state = a * state + c (mod 2^64), Knuth's MMIX constants.
        self.state = self
            .state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        self.state
    }

    /// Uniform float in `[0, 1)`.
    pub fn next_f64(&mut self) -> f64 {
        // Keep the upper 53 bits (the strongest bits of an LCG) as a clean
        // f64 mantissa, then scale by 2^-53.
        let bits = self.next_u64() >> 11;
        (bits as f64) * (1.0 / (1u64 << 53) as f64)
    }

    /// Standard normal sample N(0, 1) via the Box–Muller transform.
    ///
    /// Consumes two uniform draws. Only the cosine branch is returned; the
    /// paired sine sample is discarded.
    pub fn next_normal(&mut self) -> f64 {
        // `next_f64` is in [0, 1). Adding a tiny epsilon keeps u1 in (0, 1), so
        // `ln(u1)` is finite. For any u1 > 0 (the smallest is 2^-53) the
        // addition rounds back to u1 exactly.
        let u1 = self.next_f64() + 1e-300;
        let u2 = self.next_f64();
        (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
    }

    /// Return `true` with probability `p`.
    ///
    /// `p <= 0` never returns `true`; `p >= 1` always does.
    pub fn next_bool(&mut self, p: f64) -> bool {
        self.next_f64() < p
    }

    /// Uniform integer in `[0, max)`; returns 0 when `max == 0`.
    ///
    /// Uses the multiply-high mapping `(x * max) >> 64`, which takes the result
    /// from the high bits of the LCG output. A plain `x % max` would use the
    /// low bits instead. In an LCG with a power-of-two modulus, the lowest k
    /// bits cycle with period 2^k, so `x % 2` would strictly alternate
    /// 0, 1, 0, 1, …
    /// The residual bias of this mapping is at most `max / 2^64`.
    pub fn next_usize(&mut self, max: usize) -> usize {
        if max == 0 {
            return 0;
        }
        let wide = u128::from(self.next_u64()) * max as u128;
        // x < 2^64 implies (x * max) >> 64 < max, so the cast cannot truncate.
        (wide >> 64) as usize
    }
}
60
61/// Sample λ ~ Beta(alpha, alpha).
62///
63/// For alpha == 1 this degenerates to Uniform(0,1).
64/// For other alpha values we use the normal approximation:
65/// λ ≈ clip( 0.5 + N(0,1) * 0.5 / sqrt(2*alpha), 0, 1 )
66/// which matches the median and spread of Beta(alpha, alpha) reasonably well.
67pub(crate) fn sample_beta_symmetric(alpha: f64, rng: &mut AugRng) -> f64 {
68 if (alpha - 1.0).abs() < 1e-9 {
69 rng.next_f64()
70 } else {
71 let sigma = 0.5 / (2.0 * alpha).sqrt();
72 let z = rng.next_normal();
73 (0.5 + z * sigma).clamp(0.0, 1.0)
74 }
75}