Skip to main content

tinyquant_core/codec/
gaussian.rs

//! Deterministic ChaCha20-backed standard-normal f64 stream.
//!
//! This is the canonical Gaussian source for Rust-side rotation matrix
//! generation. It uses `ChaCha20` (via `rand_chacha::ChaCha20Rng`) for the
//! uniform stream and the basic (trigonometric) Box-Muller transform to emit
//! independent `N(0, 1)` samples in `f64`.
//!
//! The exact recipe is specified in
//! `docs/design/rust/numerical-semantics.md` §R1 and matches the pseudocode
//! in `docs/plans/rust/phase-13-rotation-numerics.md`. Any change here
//! invalidates every rotation fixture under `tests/fixtures/rotation/`.
13
14use libm::{cos, log, sin, sqrt};
15use rand_chacha::rand_core::{RngCore, SeedableRng};
16use rand_chacha::ChaCha20Rng;
17
18/// A deterministic standard-normal `f64` generator seeded from `u64`.
19///
20/// Box-Muller turns two uniform samples into two normals; the second
21/// normal is cached in `spare` and returned on the next call.
22pub(crate) struct ChaChaGaussianStream {
23    rng: ChaCha20Rng,
24    spare: Option<f64>,
25}
26
27impl ChaChaGaussianStream {
28    /// Seed a new stream using `ChaCha20Rng::seed_from_u64(seed)`.
29    pub fn new(seed: u64) -> Self {
30        Self {
31            rng: ChaCha20Rng::seed_from_u64(seed),
32            spare: None,
33        }
34    }
35
36    /// Return the next standard-normal sample.
37    pub fn next_f64(&mut self) -> f64 {
38        if let Some(spare) = self.spare.take() {
39            return spare;
40        }
41        // Polar (classical) Box-Muller: reject u1 == 0 to avoid ln(0).
42        loop {
43            let u1 = self.next_uniform();
44            let u2 = self.next_uniform();
45            if u1 <= 0.0 {
46                continue;
47            }
48            let r = sqrt(-2.0 * log(u1));
49            let theta = 2.0 * core::f64::consts::PI * u2;
50            let z0 = r * cos(theta);
51            let z1 = r * sin(theta);
52            self.spare = Some(z1);
53            return z0;
54        }
55    }
56
57    /// Return a uniform `f64` in `[0, 1)` using 53 bits of mantissa.
58    fn next_uniform(&mut self) -> f64 {
59        let n = self.rng.next_u64();
60        // The top 53 bits are exactly representable in an `f64` mantissa,
61        // so the cast is lossless; scale into `[0, 1)` by `2^-53`.
62        #[allow(clippy::cast_precision_loss)]
63        let numerator = (n >> 11) as f64;
64        // 2^53 is `9_007_199_254_740_992` — exact in `f64`.
65        numerator * (1.0_f64 / 9_007_199_254_740_992.0_f64)
66    }
67}
68
#[cfg(test)]
mod tests {
    use super::ChaChaGaussianStream;

    // Miri note: every test below pulls samples through
    // `ChaChaGaussianStream::next_f64`, whose Box-Muller step calls
    // `libm::sqrt`. `libm` 0.2.16 implements that with SSE2 inline assembly
    // (`sqrtsd`) on x86_64, which the Miri interpreter cannot execute, so the
    // tests are ignored under Miri and run as usual under `cargo test`.
    // See Phase 20 implementation notes §Miri / libm inline-asm gap.

    /// Equal seeds must yield bit-identical streams.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn same_seed_produces_identical_stream() {
        let mut left = ChaChaGaussianStream::new(42);
        let mut right = ChaChaGaussianStream::new(42);
        (0..64).for_each(|_| {
            let expected = left.next_f64().to_bits();
            assert_eq!(expected, right.next_f64().to_bits());
        });
    }

    /// Neighbouring seeds must produce at least one differing sample.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn different_seeds_diverge() {
        let mut lhs = ChaChaGaussianStream::new(42);
        let mut rhs = ChaChaGaussianStream::new(43);
        let diffs = (0..64)
            .filter(|_| lhs.next_f64().to_bits() != rhs.next_f64().to_bits())
            .count();
        assert!(diffs > 0, "different seeds must not collide bit-for-bit");
    }

    /// First two sample moments should sit near those of `N(0, 1)`.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn samples_have_reasonable_spread() {
        const SAMPLES: u32 = 2048;
        let mut stream = ChaChaGaussianStream::new(0);
        let (sum, sum_sq) = (0..SAMPLES)
            .map(|_| stream.next_f64())
            .fold((0.0f64, 0.0f64), |(s1, s2), x| (s1 + x, s2 + x * x));
        let count = f64::from(SAMPLES);
        let mean = sum / count;
        let var = sum_sq / count - mean * mean;
        assert!(mean.abs() < 0.1, "sample mean too far from 0: {mean}");
        assert!(
            (var - 1.0).abs() < 0.2,
            "sample variance too far from 1: {var}"
        );
    }
}