Skip to main content

tensorlogic_train/augmentation/
rng.rs

/// A small, deterministic Linear Congruential Generator (LCG) for augmentation.
///
/// This is intentionally lightweight and avoids pulling in the `rand` crate.
/// The multiplier and increment are Knuth's MMIX constants. Together they
/// satisfy the Hull–Dobell conditions, so the 64-bit state has full period 2^64.
///
/// Like every power-of-two-modulus LCG, the low-order state bits are weak
/// (bit `k` repeats with period `2^(k + 1)`). All public samplers therefore
/// derive their output from the high bits. Not suitable for cryptographic use.
#[derive(Debug, Clone)]
pub struct AugRng {
    state: u64,
}

impl AugRng {
    /// LCG multiplier (Knuth's MMIX).
    const MULTIPLIER: u64 = 6_364_136_223_846_793_005;
    /// LCG increment (Knuth's MMIX); odd, as a full period requires.
    const INCREMENT: u64 = 1_442_695_040_888_963_407;
    /// Constant offset added to every seed before use.
    const SEED_OFFSET: u64 = 0x9e37_79b9_7f4a_7c15;

    /// Seed the RNG. Equal seeds always yield identical sequences.
    pub fn new(seed: u64) -> Self {
        // With a nonzero increment the LCG can never get stuck (state 0 steps
        // to INCREMENT), so the offset only relocates where small seeds start.
        Self {
            state: seed.wrapping_add(Self::SEED_OFFSET),
        }
    }

    /// Advance the LCG one step (`state = a * state + c mod 2^64`) and return
    /// the new state.
    ///
    /// The low bits of the returned value are weak; consumers must use high bits.
    fn next_u64(&mut self) -> u64 {
        self.state = self
            .state
            .wrapping_mul(Self::MULTIPLIER)
            .wrapping_add(Self::INCREMENT);
        self.state
    }

    /// Uniform float in `[0, 1)`.
    pub fn next_f64(&mut self) -> f64 {
        // The top 53 bits exactly fill an f64 mantissa, so every result is a
        // multiple of 2^-53 and strictly below 1.0.
        let bits = self.next_u64() >> 11;
        bits as f64 * (1.0 / (1u64 << 53) as f64)
    }

    /// Standard normal sample N(0, 1) via the Box–Muller transform.
    ///
    /// Consumes two uniforms per call; the second (sine) variate is discarded.
    pub fn next_normal(&mut self) -> f64 {
        // `1 - U` maps [0, 1) exactly onto (0, 1], so `ln(u1)` is always finite.
        // The magnitude is then bounded by sqrt(-2 ln 2^-53) ≈ 8.6.
        let u1 = 1.0 - self.next_f64();
        let u2 = self.next_f64();
        (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
    }

    /// Return `true` with probability `p`.
    ///
    /// `p <= 0` (or NaN) never yields `true`; `p >= 1` always does. One uniform
    /// draw is consumed either way.
    pub fn next_bool(&mut self, p: f64) -> bool {
        self.next_f64() < p
    }

    /// Uniform integer in `[0, max)`; returns 0 without drawing when `max == 0`.
    pub fn next_usize(&mut self, max: usize) -> usize {
        if max == 0 {
            return 0;
        }
        // Multiply-shift reduction: take the high 64 bits of `x * max`, so the
        // result is driven by the strong high bits of the LCG state. A plain
        // `x % max` reads the low bits, where e.g. `next_usize(2)` would strictly
        // alternate 0, 1, 0, 1. The residual bias is at most max / 2^64.
        let wide = u128::from(self.next_u64()) * max as u128;
        // `wide >> 64` is < max, so the narrowing cast is lossless.
        (wide >> 64) as usize
    }
}
61/// Sample λ ~ Beta(alpha, alpha).
62///
63/// For alpha == 1 this degenerates to Uniform(0,1).
64/// For other alpha values we use the normal approximation:
65///   λ ≈ clip( 0.5 + N(0,1) * 0.5 / sqrt(2*alpha), 0, 1 )
66/// which matches the median and spread of Beta(alpha, alpha) reasonably well.
67pub(crate) fn sample_beta_symmetric(alpha: f64, rng: &mut AugRng) -> f64 {
68    if (alpha - 1.0).abs() < 1e-9 {
69        rng.next_f64()
70    } else {
71        let sigma = 0.5 / (2.0 * alpha).sqrt();
72        let z = rng.next_normal();
73        (0.5 + z * sigma).clamp(0.0, 1.0)
74    }
75}