ruvector_rabitq/rotation.rs
1//! Random orthogonal rotation.
2//!
3//! Two flavours are supported:
4//!
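//! * `HaarDense` — Haar-uniform `D×D` matrix built via Gram–Schmidt on an
//!   i.i.d. Gaussian block. `apply` is `O(D²)`; `bytes()` reports `4·D²`
//!   bytes, although the matrix is currently held twice (internal storage
//!   plus the public `matrix` mirror). This is the default, and its output
//!   stays bit-identical to previous snapshots.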
8//!
9//! * `HadamardSigned` — randomised Hadamard rotation `D₁·H·D₂·H·D₃` where
10//! each `Dᵢ` is a ±1 diagonal and `H` is the Fast Walsh–Hadamard Transform.
11//! Cost is `O(D log D)` with no matrix stored (just `3·D` signs). TurboQuant
12//! (arXiv:2504.19874 §3.2) shows this hits the "close to Haar-uniform"
13//! regime RaBitQ needs for its Johnson–Lindenstrauss-style error bound.
14//!
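//! For arbitrary `dim` the Hadamard flavour zero-pads up to the next power of
//! two, runs the butterfly there, then truncates back to `dim`. The result is
//! exactly orthogonal only for dyadic `dim`. Truncation discards the energy
//! that lands in the padded coordinates (on average a `1 − dim/padded_dim`
//! fraction), so norms and inner products shrink when `dim` sits just above a
//! power of two. For example, `dim = 513` keeps only about half the energy.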
18
19use rand::{Rng, SeedableRng};
20use rand_distr::{Distribution, StandardNormal};
21
/// Which random-rotation construction a `RandomRotation` is backed by.
///
/// Returned by `RandomRotation::kind`. This type derives serde, so variant
/// names (self-describing formats) and variant order (index-based formats)
/// are part of the serialized form. Do not rename or reorder the variants.
#[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum RandomRotationKind {
    /// Dense `D×D` Haar-uniform orthogonal matrix, built by `RandomRotation::random`.
    HaarDense,
    /// Randomised Hadamard: three random ±1 diagonals interleaved with two
    /// FWHT passes, built by `RandomRotation::hadamard`.
    HadamardSigned,
}
30
/// Internal storage mode. Kept private so we can evolve it without breaking
/// callers — users interact via `apply` / `apply_into` / `bytes` / `kind`.
///
/// NOTE: this enum still derives serde and is serialized as the `mode` field
/// of `RandomRotation`. Any change to its variants or fields (names or order)
/// therefore changes the snapshot format. Keep it stable for existing
/// snapshots to load.
#[derive(Clone, serde::Serialize, serde::Deserialize)]
enum Mode {
    /// Flattened row-major `D×D` matrix; row `i` is `matrix[i*D..(i+1)*D]`.
    HaarDense { matrix: Vec<f32> },
    /// Three ±1 sign vectors of length `padded_dim`, applied as `D₁·H·D₂·H·D₃`.
    HadamardSigned {
        // signs[0] = D₁ (applied last), signs[1] = D₂, signs[2] = D₃ (applied first).
        signs: [Vec<f32>; 3],
        // dim.next_power_of_two(): the length every FWHT pass runs at.
        padded_dim: usize,
    },
}
43
/// A random (approximately) orthogonal rotation.
///
/// Build once, apply many times. The default constructor `random` yields a
/// Haar-uniform `D×D` matrix for backward compatibility; `hadamard` opts in
/// to the `O(D log D)` HD-HD-HD variant.
#[derive(Clone, serde::Serialize, serde::Deserialize)]
pub struct RandomRotation {
    /// Backing construction. Together with `dim`, this is the only state that
    /// `apply` / `apply_into` read.
    // NOTE(review): `mode` has no `#[serde(default)]`, so snapshots written
    // before this field existed (only `dim` + `matrix`) presumably fail to
    // deserialize. Confirm against the old snapshot format.
    mode: Mode,
    /// Length of the vectors accepted and produced by `apply` / `apply_into`.
    pub dim: usize,
    /// Kept for backward compatibility with snapshots that accessed the raw
    /// matrix. Populated only for `HaarDense`; empty for Hadamard.
    ///
    /// For `HaarDense` this is a second copy of the matrix held in `mode`. It
    /// doubles memory and snapshot size, and mutating it does NOT change what
    /// `apply` computes.
    #[serde(default)]
    pub matrix: Vec<f32>,
}
58
59impl RandomRotation {
60 /// Sample a Haar-uniform orthogonal matrix of size `dim × dim`.
61 ///
62 /// Backward-compatible default: existing callers that expect a dense
63 /// matrix under `self.matrix` keep working unchanged.
64 pub fn random(dim: usize, seed: u64) -> Self {
65 let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
66 // Fill a dim×dim matrix with N(0,1) entries.
67 let mut m: Vec<Vec<f32>> = (0..dim)
68 .map(|_| {
69 (0..dim)
70 .map(|_| {
71 <StandardNormal as Distribution<f64>>::sample(&StandardNormal, &mut rng)
72 as f32
73 })
74 .collect()
75 })
76 .collect();
77
78 // Gram–Schmidt orthonormalisation (in-place).
79 for i in 0..dim {
80 // Subtract projections of all previous basis vectors.
81 for j in 0..i {
82 let dot: f32 = (0..dim).map(|k| m[i][k] * m[j][k]).sum();
83 for k in 0..dim {
84 let v = m[j][k];
85 m[i][k] -= dot * v;
86 }
87 }
88 // Normalise.
89 let norm: f32 = m[i].iter().map(|&x| x * x).sum::<f32>().sqrt();
90 if norm > 1e-10 {
91 m[i].iter_mut().for_each(|x| *x /= norm);
92 }
93 }
94
95 let matrix: Vec<f32> = m.into_iter().flatten().collect();
96 Self {
97 mode: Mode::HaarDense {
98 matrix: matrix.clone(),
99 },
100 dim,
101 matrix,
102 }
103 }
104
105 /// Construct a randomised Hadamard rotation `D₁·H·D₂·H·D₃`.
106 ///
107 /// Stores only `3 × padded_dim` ±1 entries — no matrix materialised.
108 /// `padded_dim` is the next power of two `≥ dim`; for dyadic `dim` it
109 /// equals `dim`.
110 pub fn hadamard(dim: usize, seed: u64) -> Self {
111 assert!(dim > 0, "RandomRotation::hadamard: dim must be > 0");
112 let padded_dim = dim.next_power_of_two();
113 let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
114 // Three independent ±1 sign vectors.
115 let make_signs = |rng: &mut rand::rngs::StdRng| -> Vec<f32> {
116 (0..padded_dim)
117 .map(|_| if rng.gen::<bool>() { 1.0_f32 } else { -1.0_f32 })
118 .collect()
119 };
120 let signs = [
121 make_signs(&mut rng),
122 make_signs(&mut rng),
123 make_signs(&mut rng),
124 ];
125
126 Self {
127 mode: Mode::HadamardSigned { signs, padded_dim },
128 dim,
129 matrix: Vec::new(),
130 }
131 }
132
133 /// Which construction backs this rotation.
134 #[inline]
135 pub fn kind(&self) -> RandomRotationKind {
136 match &self.mode {
137 Mode::HaarDense { .. } => RandomRotationKind::HaarDense,
138 Mode::HadamardSigned { .. } => RandomRotationKind::HadamardSigned,
139 }
140 }
141
142 /// Apply the rotation: out = P · v (length must equal dim).
143 #[inline]
144 pub fn apply(&self, v: &[f32]) -> Vec<f32> {
145 debug_assert_eq!(v.len(), self.dim);
146 let mut out = vec![0.0f32; self.dim];
147 self.apply_into(v, &mut out);
148 out
149 }
150
151 /// In-place variant that writes into a caller-provided buffer.
152 /// Callers doing many rotations (hot query path) should reuse one
153 /// `Vec<f32>` instead of allocating per call — saves one malloc
154 /// per query in the ANN index's `encode_query_packed` path.
155 #[inline]
156 pub fn apply_into(&self, v: &[f32], out: &mut [f32]) {
157 debug_assert_eq!(v.len(), self.dim);
158 debug_assert_eq!(out.len(), self.dim);
159 match &self.mode {
160 Mode::HaarDense { matrix } => {
161 let d = self.dim;
162 for (i, out_i) in out.iter_mut().enumerate() {
163 let row = &matrix[i * d..(i + 1) * d];
164 *out_i = row.iter().zip(v.iter()).map(|(&r, &x)| r * x).sum();
165 }
166 }
167 Mode::HadamardSigned { signs, padded_dim } => {
168 // Scratch buffer at padded size — zero-pad the tail.
169 let mut buf = vec![0.0_f32; *padded_dim];
170 buf[..self.dim].copy_from_slice(v);
171 // D₃
172 for (b, s) in buf.iter_mut().zip(signs[2].iter()) {
173 *b *= *s;
174 }
175 fwht_inplace(&mut buf);
176 // D₂
177 for (b, s) in buf.iter_mut().zip(signs[1].iter()) {
178 *b *= *s;
179 }
180 fwht_inplace(&mut buf);
181 // D₁
182 for (b, s) in buf.iter_mut().zip(signs[0].iter()) {
183 *b *= *s;
184 }
185 // Normalise: two FWHT passes multiply the norm by `padded_dim`
186 // (each H is orthogonal only after dividing by √padded_dim),
187 // so the combined scale factor is 1 / padded_dim.
188 let scale = 1.0_f32 / (*padded_dim as f32);
189 for (o, b) in out.iter_mut().zip(buf.iter().take(self.dim)) {
190 *o = b * scale;
191 }
192 }
193 }
194 }
195
196 /// Memory usage in bytes of the rotation's internal storage.
197 pub fn bytes(&self) -> usize {
198 match &self.mode {
199 Mode::HaarDense { matrix } => matrix.len() * 4,
200 Mode::HadamardSigned { signs, .. } => signs.iter().map(|s| s.len() * 4).sum::<usize>(),
201 }
202 }
203}
204
/// Rescale `v` in place to unit Euclidean (L2) length.
///
/// When the length is not strictly greater than `1e-10`, `v` is left as is.
/// That covers empty and all-zero slices, vectors that are too small to scale
/// safely, and NaN lengths.
pub fn normalize_inplace(v: &mut [f32]) {
    let sum_of_squares: f32 = v.iter().map(|x| x * x).sum();
    let length = sum_of_squares.sqrt();
    if length > 1e-10 {
        for component in v.iter_mut() {
            *component /= length;
        }
    }
}
212
/// In-place, unnormalised Fast Walsh–Hadamard Transform in natural
/// (Sylvester) order.
///
/// On return `buf` holds `H · buf_in`. Here `H` is the `N×N` Hadamard matrix
/// with `H · Hᵀ = N · I`, so the transform scales Euclidean norms by `√N`.
/// There are `log₂ N` stages. At each stage, every pair
/// `(buf[lo], buf[lo + span])` inside a block of `2·span` elements is replaced
/// by its sum and its difference.
///
/// `buf.len()` must be a power of two (checked by `debug_assert!`).
#[inline]
fn fwht_inplace(buf: &mut [f32]) {
    let len = buf.len();
    debug_assert!(len.is_power_of_two(), "FWHT requires power-of-two length");
    let mut span = 1;
    while span < len {
        for block_start in (0..len).step_by(2 * span) {
            for lo in block_start..block_start + span {
                let hi = lo + span;
                let (sum, diff) = (buf[lo] + buf[hi], buf[lo] - buf[hi]);
                buf[lo] = sum;
                buf[hi] = diff;
            }
        }
        span <<= 1;
    }
}
238
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::StdRng;
    use rand_distr::StandardNormal;

    /// Full orthogonality check — every pair of rows must be orthonormal.
    /// Stricter than the shipped version at `f2dbb6efb`, which only tested
    /// (row 0, row 1).
    #[test]
    fn orthogonality_all_pairs_d64() {
        check_orthonormal(64, 42, 1e-4);
    }

    #[test]
    fn orthogonality_all_pairs_d128() {
        check_orthonormal(128, 7, 1e-4);
    }

    /// At D=256 the f32 Gram–Schmidt accumulates enough round-off that we
    /// widen the tolerance to 1e-3. That is still tight enough for the
    /// estimator not to drift. `random` already uses the modified (not
    /// classical) GS update. The next step for D ≥ 1024 would be Householder
    /// QR or f64 accumulation, and either one changes the sampled bits relative
    /// to existing snapshots.
    #[test]
    fn orthogonality_all_pairs_d256() {
        check_orthonormal(256, 11, 1e-3);
    }

    /// Assert that every row of `RandomRotation::random(dim, seed)` has unit
    /// norm and is orthogonal to every later row, both within `tol`.
    fn check_orthonormal(dim: usize, seed: u64, tol: f32) {
        let rot = RandomRotation::random(dim, seed);
        let d = rot.dim;
        for i in 0..d {
            let ri = &rot.matrix[i * d..(i + 1) * d];
            // Unit norm.
            let ni: f32 = ri.iter().map(|&x| x * x).sum::<f32>().sqrt();
            assert!((ni - 1.0).abs() < tol, "row {i} norm = {ni}, D={d}");
            // Orthogonal to all later rows.
            for j in (i + 1)..d {
                let rj = &rot.matrix[j * d..(j + 1) * d];
                let dot: f32 = ri.iter().zip(rj.iter()).map(|(&a, &b)| a * b).sum();
                assert!(dot.abs() < tol, "rows {i},{j} dot={dot}, D={d}");
            }
        }
    }

    #[test]
    fn apply_preserves_norm() {
        let rot = RandomRotation::random(128, 7);
        let v: Vec<f32> = (0..128_u32).map(|i| (i as f32).sin()).collect();
        let rv = rot.apply(&v);
        let norm_in: f32 = v.iter().map(|&x| x * x).sum::<f32>().sqrt();
        let norm_out: f32 = rv.iter().map(|&x| x * x).sum::<f32>().sqrt();
        assert!((norm_in - norm_out).abs() / norm_in < 1e-3);
    }

    /// Determinism: same seed + same dim → bit-identical rotation matrix.
    #[test]
    fn seed_reproducibility() {
        let a = RandomRotation::random(64, 1234);
        let b = RandomRotation::random(64, 1234);
        assert_eq!(a.matrix, b.matrix);
    }

    /// `apply_into` with one reused output buffer must match `apply` exactly
    /// for every code path: dense Haar, dyadic Hadamard, and padded Hadamard.
    /// The NaN prefill proves every output slot is overwritten on each call.
    #[test]
    fn apply_into_matches_apply() {
        let rotations = [
            RandomRotation::random(64, 3),
            RandomRotation::hadamard(64, 3),
            RandomRotation::hadamard(100, 3),
        ];
        for rot in &rotations {
            let mut out = vec![f32::NAN; rot.dim];
            for v in random_unit_vecs(rot.dim, 5, 21) {
                rot.apply_into(&v, &mut out);
                assert_eq!(out, rot.apply(&v), "kind={:?}, D={}", rot.kind(), rot.dim);
            }
        }
    }

    // ----- Randomised Hadamard (HD-HD-HD) tests --------------------------------

    /// Sample random unit vectors via StdRng + StandardNormal (seeded → reproducible).
    fn random_unit_vecs(dim: usize, n: usize, seed: u64) -> Vec<Vec<f32>> {
        let mut rng = StdRng::seed_from_u64(seed);
        (0..n)
            .map(|_| {
                let mut v: Vec<f32> = (0..dim)
                    .map(|_| {
                        <StandardNormal as Distribution<f64>>::sample(&StandardNormal, &mut rng)
                            as f32
                    })
                    .collect();
                normalize_inplace(&mut v);
                v
            })
            .collect()
    }

    /// Rotated unit vectors must keep their norm within ±5 %.
    ///
    /// For dyadic dims the rotation is orthogonal, so this holds up to float
    /// error. For non-dyadic dims, truncation keeps only about
    /// `dim / padded_dim` of the squared norm. The band therefore holds only
    /// when `dim` is close to `padded_dim` (1000/1024 ⇒ norm ≈ 0.99).
    fn hadamard_norm_check(dim: usize, seed: u64) {
        let rot = RandomRotation::hadamard(dim, seed);
        assert_eq!(rot.kind(), RandomRotationKind::HadamardSigned);
        let vecs = random_unit_vecs(dim, 100, seed ^ 0xDEAD_BEEF);
        for v in &vecs {
            let rv = rot.apply(v);
            let n: f32 = rv.iter().map(|&x| x * x).sum::<f32>().sqrt();
            assert!(
                (0.95..=1.05).contains(&n),
                "D={dim}: rotated unit vector has norm {n}",
            );
        }
    }

    /// D=128 and D=256 are powers of two — no padding path.
    #[test]
    fn hadamard_apply_preserves_norm_power_of_two() {
        hadamard_norm_check(128, 7);
        hadamard_norm_check(256, 11);
    }

    /// D=1000 exercises the zero-pad-to-1024 branch plus the truncation back
    /// to `dim`. About 2.3 % of the squared norm is lost to the 24 padded
    /// coordinates, which is still inside the ±5 % band. A `dim` just above a
    /// power of two (e.g. 513) would NOT pass this check.
    #[test]
    fn hadamard_apply_preserves_norm_non_power_of_two() {
        hadamard_norm_check(1000, 3);
    }

    /// Dot product of two equal-length slices.
    fn dot(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(x, y)| x * y).sum()
    }

    /// For dyadic dims the HD-HD-HD rotation is exactly orthogonal, so inner
    /// products (not just norms) must survive. RaBitQ's estimator relies on this.
    #[test]
    fn hadamard_preserves_inner_products_power_of_two() {
        let dim = 256;
        let rot = RandomRotation::hadamard(dim, 5);
        let vecs = random_unit_vecs(dim, 20, 77);
        let rotated: Vec<Vec<f32>> = vecs.iter().map(|v| rot.apply(v)).collect();
        for i in 0..vecs.len() {
            for j in i..vecs.len() {
                let before = dot(&vecs[i], &vecs[j]);
                let after = dot(&rotated[i], &rotated[j]);
                assert!(
                    (before - after).abs() < 1e-4,
                    "pair ({i},{j}): <u,v>={before}, <Pu,Pv>={after}",
                );
            }
        }
    }

    /// Unnormalised Sylvester-Hadamard matrix-vector product, computed densely
    /// from `H[i][j] = (-1)^popcount(i & j)` as a reference independent of
    /// `fwht_inplace`.
    fn dense_hadamard_matvec(x: &[f32]) -> Vec<f32> {
        let n = x.len();
        (0..n)
            .map(|i| {
                (0..n)
                    .map(|j| if (i & j).count_ones() % 2 == 0 { x[j] } else { -x[j] })
                    .sum::<f32>()
            })
            .collect()
    }

    /// End-to-end check of the sign/FWHT/normalisation pipeline against an
    /// explicit dense evaluation of `D₁·H·D₂·H·D₃ / N`.
    #[test]
    fn hadamard_matches_dense_reference_d8() {
        let dim = 8;
        let rot = RandomRotation::hadamard(dim, 99);
        let signs = match &rot.mode {
            Mode::HadamardSigned { signs, padded_dim } => {
                assert_eq!(*padded_dim, dim);
                signs.clone()
            }
            Mode::HaarDense { .. } => panic!("hadamard() must build a HadamardSigned rotation"),
        };
        let scale = 1.0 / dim as f32;
        let x: Vec<f32> = (0..dim).map(|i| i as f32 * 0.5 - 1.75).collect();
        // D₃, H, D₂, H, D₁, then the 1/N normalisation.
        let mut y: Vec<f32> = x.iter().zip(&signs[2]).map(|(a, s)| a * s).collect();
        y = dense_hadamard_matvec(&y);
        y.iter_mut().zip(&signs[1]).for_each(|(a, s)| *a *= s);
        y = dense_hadamard_matvec(&y);
        y.iter_mut().zip(&signs[0]).for_each(|(a, s)| *a *= s * scale);
        let got = rot.apply(&x);
        for (i, (g, e)) in got.iter().zip(&y).enumerate() {
            assert!((g - e).abs() < 1e-5, "coord {i}: got {g}, expected {e}");
        }
    }

    /// Same seed → bit-identical output (all three sign vectors reproduced),
    /// checked through `apply`.
    #[test]
    fn hadamard_is_deterministic() {
        let a = RandomRotation::hadamard(128, 0xC0FFEE);
        let b = RandomRotation::hadamard(128, 0xC0FFEE);
        let v: Vec<f32> = (0..128_u32).map(|i| (i as f32).cos()).collect();
        assert_eq!(a.apply(&v), b.apply(&v));
        // Different seed must change the output.
        let c = RandomRotation::hadamard(128, 0xC0FFEE + 1);
        assert_ne!(a.apply(&v), c.apply(&v));
    }

    /// Correctness smoke test for a dyadic dim. The all-ones input after one
    /// FWHT collapses to `(N, 0, 0, …)`, which verifies the butterfly cheaply
    /// without any timing. Also checks the storage footprint.
    #[test]
    fn hadamard_is_fast() {
        // FWHT of `[1; 8]` must be `[8, 0, 0, 0, 0, 0, 0, 0]`.
        let mut buf = vec![1.0_f32; 8];
        fwht_inplace(&mut buf);
        assert!((buf[0] - 8.0).abs() < 1e-6);
        for v in &buf[1..] {
            assert!(v.abs() < 1e-6);
        }

        // Storage footprint: Hadamard must be dramatically smaller than Haar
        // at non-trivial dim (3·D floats vs D² floats).
        let had = RandomRotation::hadamard(128, 1);
        let haar = RandomRotation::random(128, 1);
        assert!(had.bytes() < haar.bytes() / 10);
        assert_eq!(had.kind(), RandomRotationKind::HadamardSigned);
        assert_eq!(haar.kind(), RandomRotationKind::HaarDense);
    }
}