Skip to main content

ruvector_rabitq/
rotation.rs

1//! Random orthogonal rotation.
2//!
3//! Two flavours are supported:
4//!
5//! * `HaarDense` — Haar-uniform `D×D` matrix built via Gram–Schmidt on an
6//!   i.i.d. Gaussian block. `apply` is `O(D²)`; storage is `4·D²` bytes. This
7//!   is the default and stays bit-identical to previous snapshots.
8//!
9//! * `HadamardSigned` — randomised Hadamard rotation `D₁·H·D₂·H·D₃` where
10//!   each `Dᵢ` is a ±1 diagonal and `H` is the Fast Walsh–Hadamard Transform.
11//!   Cost is `O(D log D)` with no matrix stored (just `3·D` signs). TurboQuant
12//!   (arXiv:2504.19874 §3.2) shows this hits the "close to Haar-uniform"
13//!   regime RaBitQ needs for its Johnson–Lindenstrauss-style error bound.
14//!
15//! For arbitrary `dim` the Hadamard flavour zero-pads up to the next power of
16//! two, runs the butterfly there, then truncates back to `dim` — standard
17//! FWHT-on-non-dyadic trick.
18
19use rand::{Rng, SeedableRng};
20use rand_distr::{Distribution, StandardNormal};
21
22/// Which random-rotation construction a `RandomRotation` is backed by.
23#[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
24pub enum RandomRotationKind {
25    /// Dense `D×D` Haar-uniform orthogonal matrix.
26    HaarDense,
27    /// Randomised Hadamard: three random ±1 diagonals interleaved with FWHT.
28    HadamardSigned,
29}
30
31/// Internal storage mode. Kept private so we can evolve it without breaking
32/// callers — users interact via `apply` / `apply_into` / `bytes` / `kind`.
33#[derive(Clone, serde::Serialize, serde::Deserialize)]
34enum Mode {
35    /// Flattened row-major `D×D` matrix.
36    HaarDense { matrix: Vec<f32> },
37    /// Three ±1 sign vectors of length `padded_dim`, applied as `D₁·H·D₂·H·D₃`.
38    HadamardSigned {
39        signs: [Vec<f32>; 3],
40        padded_dim: usize,
41    },
42}
43
44/// A random (approximately) orthogonal rotation.
45///
46/// Build once, apply many times. The default constructor `random` yields a
47/// Haar-uniform `D×D` matrix for backward compatibility; `hadamard` opts in
48/// to the `O(D log D)` HD-HD-HD variant.
49#[derive(Clone, serde::Serialize, serde::Deserialize)]
50pub struct RandomRotation {
51    mode: Mode,
52    pub dim: usize,
53    /// Kept for backward compatibility with snapshots that accessed the raw
54    /// matrix. Populated only for `HaarDense`; empty for Hadamard.
55    #[serde(default)]
56    pub matrix: Vec<f32>,
57}
58
59impl RandomRotation {
60    /// Sample a Haar-uniform orthogonal matrix of size `dim × dim`.
61    ///
62    /// Backward-compatible default: existing callers that expect a dense
63    /// matrix under `self.matrix` keep working unchanged.
64    pub fn random(dim: usize, seed: u64) -> Self {
65        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
66        // Fill a dim×dim matrix with N(0,1) entries.
67        let mut m: Vec<Vec<f32>> = (0..dim)
68            .map(|_| {
69                (0..dim)
70                    .map(|_| {
71                        <StandardNormal as Distribution<f64>>::sample(&StandardNormal, &mut rng)
72                            as f32
73                    })
74                    .collect()
75            })
76            .collect();
77
78        // Gram–Schmidt orthonormalisation (in-place).
79        for i in 0..dim {
80            // Subtract projections of all previous basis vectors.
81            for j in 0..i {
82                let dot: f32 = (0..dim).map(|k| m[i][k] * m[j][k]).sum();
83                for k in 0..dim {
84                    let v = m[j][k];
85                    m[i][k] -= dot * v;
86                }
87            }
88            // Normalise.
89            let norm: f32 = m[i].iter().map(|&x| x * x).sum::<f32>().sqrt();
90            if norm > 1e-10 {
91                m[i].iter_mut().for_each(|x| *x /= norm);
92            }
93        }
94
95        let matrix: Vec<f32> = m.into_iter().flatten().collect();
96        Self {
97            mode: Mode::HaarDense {
98                matrix: matrix.clone(),
99            },
100            dim,
101            matrix,
102        }
103    }
104
105    /// Construct a randomised Hadamard rotation `D₁·H·D₂·H·D₃`.
106    ///
107    /// Stores only `3 × padded_dim` ±1 entries — no matrix materialised.
108    /// `padded_dim` is the next power of two `≥ dim`; for dyadic `dim` it
109    /// equals `dim`.
110    pub fn hadamard(dim: usize, seed: u64) -> Self {
111        assert!(dim > 0, "RandomRotation::hadamard: dim must be > 0");
112        let padded_dim = dim.next_power_of_two();
113        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
114        // Three independent ±1 sign vectors.
115        let make_signs = |rng: &mut rand::rngs::StdRng| -> Vec<f32> {
116            (0..padded_dim)
117                .map(|_| if rng.gen::<bool>() { 1.0_f32 } else { -1.0_f32 })
118                .collect()
119        };
120        let signs = [
121            make_signs(&mut rng),
122            make_signs(&mut rng),
123            make_signs(&mut rng),
124        ];
125
126        Self {
127            mode: Mode::HadamardSigned { signs, padded_dim },
128            dim,
129            matrix: Vec::new(),
130        }
131    }
132
133    /// Which construction backs this rotation.
134    #[inline]
135    pub fn kind(&self) -> RandomRotationKind {
136        match &self.mode {
137            Mode::HaarDense { .. } => RandomRotationKind::HaarDense,
138            Mode::HadamardSigned { .. } => RandomRotationKind::HadamardSigned,
139        }
140    }
141
142    /// Apply the rotation: out = P · v  (length must equal dim).
143    #[inline]
144    pub fn apply(&self, v: &[f32]) -> Vec<f32> {
145        debug_assert_eq!(v.len(), self.dim);
146        let mut out = vec![0.0f32; self.dim];
147        self.apply_into(v, &mut out);
148        out
149    }
150
151    /// In-place variant that writes into a caller-provided buffer.
152    /// Callers doing many rotations (hot query path) should reuse one
153    /// `Vec<f32>` instead of allocating per call — saves one malloc
154    /// per query in the ANN index's `encode_query_packed` path.
155    #[inline]
156    pub fn apply_into(&self, v: &[f32], out: &mut [f32]) {
157        debug_assert_eq!(v.len(), self.dim);
158        debug_assert_eq!(out.len(), self.dim);
159        match &self.mode {
160            Mode::HaarDense { matrix } => {
161                let d = self.dim;
162                for (i, out_i) in out.iter_mut().enumerate() {
163                    let row = &matrix[i * d..(i + 1) * d];
164                    *out_i = row.iter().zip(v.iter()).map(|(&r, &x)| r * x).sum();
165                }
166            }
167            Mode::HadamardSigned { signs, padded_dim } => {
168                // Scratch buffer at padded size — zero-pad the tail.
169                let mut buf = vec![0.0_f32; *padded_dim];
170                buf[..self.dim].copy_from_slice(v);
171                // D₃
172                for (b, s) in buf.iter_mut().zip(signs[2].iter()) {
173                    *b *= *s;
174                }
175                fwht_inplace(&mut buf);
176                // D₂
177                for (b, s) in buf.iter_mut().zip(signs[1].iter()) {
178                    *b *= *s;
179                }
180                fwht_inplace(&mut buf);
181                // D₁
182                for (b, s) in buf.iter_mut().zip(signs[0].iter()) {
183                    *b *= *s;
184                }
185                // Normalise: two FWHT passes multiply the norm by `padded_dim`
186                // (each H is orthogonal only after dividing by √padded_dim),
187                // so the combined scale factor is 1 / padded_dim.
188                let scale = 1.0_f32 / (*padded_dim as f32);
189                for (o, b) in out.iter_mut().zip(buf.iter().take(self.dim)) {
190                    *o = b * scale;
191                }
192            }
193        }
194    }
195
196    /// Memory usage in bytes of the rotation's internal storage.
197    pub fn bytes(&self) -> usize {
198        match &self.mode {
199            Mode::HaarDense { matrix } => matrix.len() * 4,
200            Mode::HadamardSigned { signs, .. } => signs.iter().map(|s| s.len() * 4).sum::<usize>(),
201        }
202    }
203}
204
/// Rescale `v` in place so that its Euclidean (L2) norm becomes 1.
///
/// Vectors whose norm is not above `1e-10` — including the empty and the
/// all-zero vector — are left unchanged to avoid dividing by ~0.
pub fn normalize_inplace(v: &mut [f32]) {
    let sum_sq: f32 = v.iter().map(|&x| x * x).sum();
    let norm = sum_sq.sqrt();
    if norm > 1e-10 {
        for x in v.iter_mut() {
            *x /= norm;
        }
    }
}
212
/// Unnormalised Fast Walsh–Hadamard Transform, computed in place in natural
/// (Hadamard) order.
///
/// `buf.len()` must be a power of two. The transform runs `log₂ N` butterfly
/// stages; at the stage with half-width `half`, the buffer is cut into
/// blocks of `2 · half` elements and every pair `(lo[k], hi[k])` of a block's
/// two halves is replaced by its sum and its difference. On return `buf`
/// holds `H · buf_in`, where `H` is the unnormalised Hadamard matrix
/// (`H Hᵀ = N · I`).
#[inline]
fn fwht_inplace(buf: &mut [f32]) {
    let n = buf.len();
    debug_assert!(n.is_power_of_two(), "FWHT requires power-of-two length");
    let mut half = 1;
    while half < n {
        for block in buf.chunks_exact_mut(half * 2) {
            let (lo, hi) = block.split_at_mut(half);
            for (a, b) in lo.iter_mut().zip(hi.iter_mut()) {
                let (x, y) = (*a, *b);
                *a = x + y;
                *b = x - y;
            }
        }
        half *= 2;
    }
}
238
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::StdRng;
    use rand_distr::StandardNormal;

    /// Euclidean norm of a slice, accumulated in `f32`.
    fn l2_norm(v: &[f32]) -> f32 {
        v.iter().map(|&x| x * x).sum::<f32>().sqrt()
    }

    // ----- Dense Haar rotation tests -------------------------------------------

    /// Full orthogonality check — every pair of rows must be orthonormal.
    /// Stricter than the shipped version at `f2dbb6efb`, which only tested
    /// (row 0, row 1).
    #[test]
    fn orthogonality_all_pairs_d64() {
        check_orthonormal(64, 42, 1e-4);
    }

    #[test]
    fn orthogonality_all_pairs_d128() {
        check_orthonormal(128, 7, 1e-4);
    }

    /// At D=256 the f32 Gram–Schmidt sweep in `random` accumulates enough
    /// round-off that the tolerance is widened to 1e-3 — still tight enough
    /// for the estimator not to drift. Note that the sweep already subtracts
    /// projections from the updated row (the "modified" form); a Householder
    /// construction or wider accumulation is the remaining option once
    /// D ≥ 1024 ships.
    #[test]
    fn orthogonality_all_pairs_d256() {
        check_orthonormal(256, 11, 1e-3);
    }

    /// Assert that the rows of `RandomRotation::random(dim, seed)` have unit
    /// norm and are pairwise orthogonal, each to within `tol`.
    fn check_orthonormal(dim: usize, seed: u64, tol: f32) {
        let rot = RandomRotation::random(dim, seed);
        let d = rot.dim;
        let rows: Vec<&[f32]> = rot.matrix.chunks_exact(d).collect();
        for (i, ri) in rows.iter().enumerate() {
            // Unit norm.
            let ni = l2_norm(ri);
            assert!((ni - 1.0).abs() < tol, "row {i} norm = {ni}, D={d}");
            // Orthogonal to all later rows.
            for (j, rj) in rows.iter().enumerate().skip(i + 1) {
                let dot: f32 = ri.iter().zip(rj.iter()).map(|(&a, &b)| a * b).sum();
                assert!(dot.abs() < tol, "rows {i},{j} dot={dot}, D={d}");
            }
        }
    }

    #[test]
    fn apply_preserves_norm() {
        let rot = RandomRotation::random(128, 7);
        let v: Vec<f32> = (0..128_u32).map(|i| (i as f32).sin()).collect();
        let rv = rot.apply(&v);
        let norm_in = l2_norm(&v);
        let norm_out = l2_norm(&rv);
        assert!((norm_in - norm_out).abs() / norm_in < 1e-3);
    }

    /// Determinism: same seed + same dim → bit-identical rotation matrix.
    #[test]
    fn seed_reproducibility() {
        let a = RandomRotation::random(64, 1234);
        let b = RandomRotation::random(64, 1234);
        assert_eq!(a.matrix, b.matrix);
    }

    // ----- Randomised Hadamard (HD-HD-HD) tests --------------------------------

    /// Draw `n` random unit vectors of length `dim` from a seeded `StdRng`
    /// via `StandardNormal` (seeded → reproducible).
    fn random_unit_vecs(dim: usize, n: usize, seed: u64) -> Vec<Vec<f32>> {
        let mut rng = StdRng::seed_from_u64(seed);
        let mut vecs: Vec<Vec<f32>> = Vec::with_capacity(n);
        for _ in 0..n {
            let mut v: Vec<f32> = (0..dim)
                .map(|_| {
                    <StandardNormal as Distribution<f64>>::sample(&StandardNormal, &mut rng)
                        as f32
                })
                .collect();
            normalize_inplace(&mut v);
            vecs.push(v);
        }
        vecs
    }

    /// Rotate 100 random unit vectors with a Hadamard rotation and require
    /// every output norm to stay within ±5 % of 1.
    fn hadamard_norm_check(dim: usize, seed: u64) {
        let rot = RandomRotation::hadamard(dim, seed);
        assert_eq!(rot.kind(), RandomRotationKind::HadamardSigned);
        for v in random_unit_vecs(dim, 100, seed ^ 0xDEAD_BEEF).iter() {
            let rv = rot.apply(v);
            let n = l2_norm(&rv);
            // Isotropy is approximate (truncation + padding break exact
            // orthogonality) — loose ±5 % band keeps RaBitQ estimator safe.
            assert!(
                (0.95..=1.05).contains(&n),
                "D={dim}: rotated unit vector has norm {n}",
            );
        }
    }

    /// D=128 and D=256 are powers of two — no padding path.
    #[test]
    fn hadamard_apply_preserves_norm_power_of_two() {
        hadamard_norm_check(128, 7);
        hadamard_norm_check(256, 11);
    }

    /// D=1000 exercises the zero-pad-to-1024 branch plus the truncation
    /// back to `dim`. Looser isotropy is expected and allowed by the ±5 %
    /// tolerance.
    #[test]
    fn hadamard_apply_preserves_norm_non_power_of_two() {
        hadamard_norm_check(1000, 3);
    }

    /// Same seed → bit-identical output (all three sign vectors agree, as
    /// observed through `apply`); a different seed must change the output.
    #[test]
    fn hadamard_is_deterministic() {
        let v: Vec<f32> = (0..128_u32).map(|i| (i as f32).cos()).collect();
        let a = RandomRotation::hadamard(128, 0xC0FFEE);
        let b = RandomRotation::hadamard(128, 0xC0FFEE);
        assert_eq!(a.apply(&v), b.apply(&v));
        // Different seed must change the output.
        let c = RandomRotation::hadamard(128, 0xC0FFEE + 1);
        assert_ne!(a.apply(&v), c.apply(&v));
    }

    /// Correctness smoke (despite the name, nothing is timed): for a dyadic
    /// length the FWHT of the all-ones input collapses to `(N, 0, 0, …)`,
    /// which cheaply verifies the butterfly. Also checks that Hadamard
    /// storage is far smaller than dense storage.
    #[test]
    fn hadamard_is_fast() {
        // FWHT of `[1; 8]` must be `[8, 0, 0, 0, 0, 0, 0, 0]`.
        let mut buf = vec![1.0_f32; 8];
        fwht_inplace(&mut buf);
        assert!((buf[0] - 8.0).abs() < 1e-6);
        for v in buf.iter().skip(1) {
            assert!(v.abs() < 1e-6);
        }

        // Storage footprint: Hadamard must be dramatically smaller than Haar
        // at non-trivial dim (3·D floats vs D² floats).
        let had = RandomRotation::hadamard(128, 1);
        let haar = RandomRotation::random(128, 1);
        assert!(had.bytes() < haar.bytes() / 10);
        assert_eq!(had.kind(), RandomRotationKind::HadamardSigned);
        assert_eq!(haar.kind(), RandomRotationKind::HaarDense);
    }
}