tinyquant_core/codec/gaussian.rs
1//! Deterministic ChaCha20-backed standard-normal f64 stream.
2//!
3//! This is the canonical Gaussian source for Rust-side rotation matrix
4//! generation. It uses `ChaCha20` (via `rand_chacha::ChaCha20Rng`) for the
//! uniform stream and the basic (trigonometric) Box-Muller transform to emit
//! independent `N(0, 1)` samples in `f64`.
7//!
8//! The exact recipe is specified in
9//! `docs/design/rust/numerical-semantics.md` §R1 and matches the pseudo-
10//! code in `docs/plans/rust/phase-13-rotation-numerics.md`. Any change
11//! here invalidates every rotation fixture under
12//! `tests/fixtures/rotation/`.
13
14use libm::{cos, log, sin, sqrt};
15use rand_chacha::rand_core::{RngCore, SeedableRng};
16use rand_chacha::ChaCha20Rng;
17
18/// A deterministic standard-normal `f64` generator seeded from `u64`.
19///
20/// Box-Muller turns two uniform samples into two normals; the second
21/// normal is cached in `spare` and returned on the next call.
22pub(crate) struct ChaChaGaussianStream {
23 rng: ChaCha20Rng,
24 spare: Option<f64>,
25}
26
27impl ChaChaGaussianStream {
28 /// Seed a new stream using `ChaCha20Rng::seed_from_u64(seed)`.
29 pub fn new(seed: u64) -> Self {
30 Self {
31 rng: ChaCha20Rng::seed_from_u64(seed),
32 spare: None,
33 }
34 }
35
36 /// Return the next standard-normal sample.
37 pub fn next_f64(&mut self) -> f64 {
38 if let Some(spare) = self.spare.take() {
39 return spare;
40 }
41 // Polar (classical) Box-Muller: reject u1 == 0 to avoid ln(0).
42 loop {
43 let u1 = self.next_uniform();
44 let u2 = self.next_uniform();
45 if u1 <= 0.0 {
46 continue;
47 }
48 let r = sqrt(-2.0 * log(u1));
49 let theta = 2.0 * core::f64::consts::PI * u2;
50 let z0 = r * cos(theta);
51 let z1 = r * sin(theta);
52 self.spare = Some(z1);
53 return z0;
54 }
55 }
56
57 /// Return a uniform `f64` in `[0, 1)` using 53 bits of mantissa.
58 fn next_uniform(&mut self) -> f64 {
59 let n = self.rng.next_u64();
60 // The top 53 bits are exactly representable in an `f64` mantissa,
61 // so the cast is lossless; scale into `[0, 1)` by `2^-53`.
62 #[allow(clippy::cast_precision_loss)]
63 let numerator = (n >> 11) as f64;
64 // 2^53 is `9_007_199_254_740_992` — exact in `f64`.
65 numerator * (1.0_f64 / 9_007_199_254_740_992.0_f64)
66 }
67}
68
#[cfg(test)]
mod tests {
    use super::ChaChaGaussianStream;

    // Every test below calls `ChaChaGaussianStream::next_f64`, which reaches
    // `libm::sqrt` through the Box-Muller transform. On x86_64, `libm` 0.2.16
    // implements `f64::sqrt` with SSE2 inline assembly (`sqrtsd`), and Miri's
    // interpreter cannot execute inline assembly. The tests are therefore
    // ignored under Miri and run normally under `cargo test`.
    // See Phase 20 implementation notes §Miri / libm inline-asm gap.

    /// Two streams built from the same seed must agree bit-for-bit.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn same_seed_produces_identical_stream() {
        let mut left = ChaChaGaussianStream::new(42);
        let mut right = ChaChaGaussianStream::new(42);
        for _ in 0..64 {
            let (l, r) = (left.next_f64(), right.next_f64());
            assert_eq!(l.to_bits(), r.to_bits());
        }
    }

    /// Streams built from different seeds must differ somewhere in 64 draws.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn different_seeds_diverge() {
        let mut left = ChaChaGaussianStream::new(42);
        let mut right = ChaChaGaussianStream::new(43);
        let diffs = (0..64)
            .filter(|_| left.next_f64().to_bits() != right.next_f64().to_bits())
            .count();
        assert!(diffs > 0, "different seeds must not collide bit-for-bit");
    }

    /// Coarse sanity check: 2048 samples should have mean near 0 and
    /// variance near 1.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn samples_have_reasonable_spread() {
        const SAMPLES: u32 = 2048;

        let mut stream = ChaChaGaussianStream::new(0);
        // Accumulate the running sum and sum of squares in draw order.
        let (sum, sum_sq) = (0..SAMPLES).fold((0.0_f64, 0.0_f64), |(acc, acc_sq), _| {
            let x = stream.next_f64();
            (acc + x, acc_sq + x * x)
        });

        let count = f64::from(SAMPLES);
        let mean = sum / count;
        let var = sum_sq / count - mean * mean;
        assert!(mean.abs() < 0.1, "sample mean too far from 0: {mean}");
        assert!(
            (var - 1.0).abs() < 0.2,
            "sample variance too far from 1: {var}"
        );
    }
}