Skip to main content

turbo_quant/
rotation.rs

1//! Random rotation matrices for whitening high-dimensional vectors before quantization.
2//!
3//! Rotating the data before quantization simplifies the geometry, making uniform
4//! scalar quantizers easier to use as deterministic baselines.
5//!
6//! Two implementations are provided:
7//! - [`StoredRotation`]: full d×d orthogonal matrix via QR decomposition. Correct
8//!   for any dimension but uses O(d²) memory (~9MB at d=1536).
9//! - [`FastHadamardRotation`]: deterministic sign rotation followed by a
10//!   normalized Hadamard transform for power-of-two dimensions.
11
12use crate::error::{Result, TurboQuantError};
13use nalgebra::DMatrix;
14use rand::{Rng, SeedableRng};
15use rand_chacha::ChaCha8Rng;
16use rand_distr::{Distribution, StandardNormal};
17use schemars::JsonSchema;
18use serde::{Deserialize, Serialize};
19
20/// Rotation selection policy.
21#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
22pub enum RotationKind {
23    /// Use FastHadamard for supported dimensions, otherwise Stored QR.
24    Auto,
25    /// Use deterministic Hadamard/SRHT-style rotation. Requires power-of-two dimensions.
26    FastHadamard,
27    /// Use dense QR reference rotation.
28    StoredQr,
29}
30
31impl RotationKind {
32    pub fn label(self) -> &'static str {
33        match self {
34            Self::Auto => "auto",
35            Self::FastHadamard => "fast_hadamard",
36            Self::StoredQr => "stored_qr_reference",
37        }
38    }
39}
40
41/// Concrete rotation backend used by quantizers.
42#[derive(Debug, Clone, Serialize, Deserialize)]
43pub enum RotationBackend {
44    FastHadamard(FastHadamardRotation),
45    StoredQr(StoredRotation),
46}
47
48impl RotationBackend {
49    pub fn new(dim: usize, seed: u64, kind: RotationKind) -> Result<Self> {
50        match kind {
51            RotationKind::Auto if dim.is_power_of_two() => {
52                FastHadamardRotation::new(dim, seed).map(Self::FastHadamard)
53            }
54            RotationKind::Auto => StoredRotation::new(dim, seed).map(Self::StoredQr),
55            RotationKind::FastHadamard => {
56                FastHadamardRotation::new(dim, seed).map(Self::FastHadamard)
57            }
58            RotationKind::StoredQr => StoredRotation::new(dim, seed).map(Self::StoredQr),
59        }
60    }
61
62    pub fn kind(&self) -> RotationKind {
63        match self {
64            Self::FastHadamard(_) => RotationKind::FastHadamard,
65            Self::StoredQr(_) => RotationKind::StoredQr,
66        }
67    }
68
69    pub fn kind_label(&self) -> &'static str {
70        self.kind().label()
71    }
72
73    pub fn seed(&self) -> u64 {
74        match self {
75            Self::FastHadamard(rotation) => rotation.seed(),
76            Self::StoredQr(rotation) => rotation.seed(),
77        }
78    }
79}
80
81impl Rotation for RotationBackend {
82    fn dim(&self) -> usize {
83        match self {
84            Self::FastHadamard(rotation) => rotation.dim(),
85            Self::StoredQr(rotation) => rotation.dim(),
86        }
87    }
88
89    fn apply(&self, input: &[f32], output: &mut [f32]) -> Result<()> {
90        match self {
91            Self::FastHadamard(rotation) => rotation.apply(input, output),
92            Self::StoredQr(rotation) => rotation.apply(input, output),
93        }
94    }
95
96    fn apply_inverse(&self, input: &[f32], output: &mut [f32]) -> Result<()> {
97        match self {
98            Self::FastHadamard(rotation) => rotation.apply_inverse(input, output),
99            Self::StoredQr(rotation) => rotation.apply_inverse(input, output),
100        }
101    }
102}
103
104impl RotationBackend {
105    /// Apply the inverse rotation to a batch of `dim`-sized slices in one
106    /// call. For `FastHadamard` this is the same per-vector math as
107    /// `apply_inverse` but amortized across the whole batch. For
108    /// `StoredQr` the d×d matrix is converted to a row-major `Vec<f32>`
109    /// once and reused across the batch.
110    pub fn apply_inverse_batch(&self, inputs: &[&[f32]]) -> Result<Vec<Vec<f32>>> {
111        match self {
112            Self::FastHadamard(rotation) => rotation.apply_inverse_batch(inputs),
113            Self::StoredQr(rotation) => rotation.apply_inverse_batch(inputs),
114        }
115    }
116}
117
118/// A rotation that can be applied to and inverted on vectors of a fixed dimension.
119pub trait Rotation: Send + Sync {
120    /// The dimension this rotation operates on.
121    fn dim(&self) -> usize;
122
123    /// Apply the rotation: y = R · x.
124    ///
125    /// `input` and `output` must both have length `dim()`.
126    fn apply(&self, input: &[f32], output: &mut [f32]) -> Result<()>;
127
128    /// Apply the inverse (transpose) rotation: x = Rᵀ · y.
129    ///
130    /// For orthogonal matrices, the inverse equals the transpose.
131    fn apply_inverse(&self, input: &[f32], output: &mut [f32]) -> Result<()>;
132}
133
134/// Deterministic Hadamard/SRHT-style rotation for power-of-two dimensions.
135#[derive(Debug, Clone, Serialize, Deserialize)]
136pub struct FastHadamardRotation {
137    dim: usize,
138    seed: u64,
139    signs: Vec<f32>,
140}
141
142impl FastHadamardRotation {
143    pub fn new(dim: usize, seed: u64) -> Result<Self> {
144        if dim == 0 {
145            return Err(TurboQuantError::ZeroDimension);
146        }
147        if !dim.is_power_of_two() {
148            return Err(TurboQuantError::RotationFailed {
149                reason: format!("Hadamard rotation requires a power-of-two dimension, got {dim}"),
150            });
151        }
152        let mut rng = ChaCha8Rng::seed_from_u64(seed.wrapping_add(0xA11C_E55E_D5A5_EED5));
153        let signs = (0..dim)
154            .map(|_| if rng.gen::<bool>() { 1.0 } else { -1.0 })
155            .collect();
156        Ok(Self { dim, seed, signs })
157    }
158
159    pub fn seed(&self) -> u64 {
160        self.seed
161    }
162
163    /// Apply the inverse rotation to a batch of `dim`-sized slices in one
164    /// call. For each input slice: copy it into a freshly-allocated
165    /// `dim`-sized output, run the normalized FWHT, then multiply by the
166    /// sign vector. This is bit-exact identical to calling
167    /// `apply_inverse` N times in a loop; the win is amortizing the
168    /// per-call branch/lookup overhead and keeping `scale`, the
169    /// butterfly indices, and the `signs` table hot in cache.
170    pub fn apply_inverse_batch(&self, inputs: &[&[f32]]) -> Result<Vec<Vec<f32>>> {
171        if inputs.is_empty() {
172            return Ok(Vec::new());
173        }
174        let dim = self.dim;
175        let signs = &self.signs;
176        let mut outputs: Vec<Vec<f32>> = Vec::with_capacity(inputs.len());
177        for input in inputs {
178            if input.len() != dim {
179                return Err(TurboQuantError::DimensionMismatch {
180                    expected: dim,
181                    got: input.len(),
182                });
183            }
184            let mut out = vec![0.0f32; dim];
185            // Match `apply_inverse` byte-for-byte: copy, fwht, sign-flip.
186            out.copy_from_slice(input);
187            fwht_normalized(&mut out);
188            for (out_val, sign) in out.iter_mut().zip(signs.iter()) {
189                *out_val *= *sign;
190            }
191            outputs.push(out);
192        }
193        Ok(outputs)
194    }
195}
196
197impl Rotation for FastHadamardRotation {
198    fn dim(&self) -> usize {
199        self.dim
200    }
201
202    fn apply(&self, input: &[f32], output: &mut [f32]) -> Result<()> {
203        check_dim(input.len(), self.dim)?;
204        check_dim(output.len(), self.dim)?;
205        for ((out, value), sign) in output.iter_mut().zip(input.iter()).zip(self.signs.iter()) {
206            *out = value * sign;
207        }
208        fwht_normalized(output);
209        Ok(())
210    }
211
212    fn apply_inverse(&self, input: &[f32], output: &mut [f32]) -> Result<()> {
213        check_dim(input.len(), self.dim)?;
214        check_dim(output.len(), self.dim)?;
215        output.copy_from_slice(input);
216        fwht_normalized(output);
217        for (out, sign) in output.iter_mut().zip(self.signs.iter()) {
218            *out *= sign;
219        }
220        Ok(())
221    }
222}
223
224/// A full d×d orthogonal rotation matrix generated via QR decomposition of a
225/// random Gaussian matrix.
226///
227/// Seeded deterministically so that quantizer state can be serialized and
228/// reconstructed without storing the matrix itself — only the seed and dimension
229/// need to be persisted.
230#[derive(Debug, Clone, Serialize, Deserialize)]
231pub struct StoredRotation {
232    dim: usize,
233    seed: u64,
234    /// Row-major flat storage of the d×d orthogonal matrix.
235    #[serde(with = "matrix_serde")]
236    matrix: DMatrix<f32>,
237}
238
239impl StoredRotation {
240    /// Generate a new rotation for vectors of dimension `dim` using `seed`.
241    ///
242    /// The same `(dim, seed)` pair always produces the same rotation, making
243    /// this suitable for deterministic, reproducible compression pipelines.
244    pub fn new(dim: usize, seed: u64) -> Result<Self> {
245        if dim == 0 {
246            return Err(TurboQuantError::ZeroDimension);
247        }
248
249        let matrix = generate_orthogonal(dim, seed)?;
250        Ok(Self { dim, seed, matrix })
251    }
252
253    /// The seed used to generate this rotation.
254    pub fn seed(&self) -> u64 {
255        self.seed
256    }
257
258    /// Approximate memory used by the stored matrix in bytes.
259    pub fn memory_bytes(&self) -> usize {
260        self.dim * self.dim * std::mem::size_of::<f32>()
261    }
262}
263
264impl Rotation for StoredRotation {
265    fn dim(&self) -> usize {
266        self.dim
267    }
268
269    fn apply(&self, input: &[f32], output: &mut [f32]) -> Result<()> {
270        check_dim(input.len(), self.dim)?;
271        check_dim(output.len(), self.dim)?;
272
273        // y = R · x  (matrix is stored column-major in nalgebra)
274        for (i, out) in output.iter_mut().enumerate() {
275            *out = self
276                .matrix
277                .row(i)
278                .iter()
279                .zip(input)
280                .map(|(r, x)| r * x)
281                .sum();
282        }
283        Ok(())
284    }
285
286    fn apply_inverse(&self, input: &[f32], output: &mut [f32]) -> Result<()> {
287        check_dim(input.len(), self.dim)?;
288        check_dim(output.len(), self.dim)?;
289
290        // x = Rᵀ · y  — for orthogonal R, R⁻¹ = Rᵀ
291        for (i, out) in output.iter_mut().enumerate() {
292            *out = self
293                .matrix
294                .column(i)
295                .iter()
296                .zip(input)
297                .map(|(r, y)| r * y)
298                .sum();
299        }
300        Ok(())
301    }
302}
303
304impl StoredRotation {
305    /// Apply the inverse rotation to a batch of `dim`-sized slices in one
306    /// call. The d×d matrix is already in memory; this is the same
307    /// per-vector work as `apply_inverse` repeated N times. The win is
308    /// just the loop / branch amortization on a tight inner loop.
309    pub fn apply_inverse_batch(&self, inputs: &[&[f32]]) -> Result<Vec<Vec<f32>>> {
310        if inputs.is_empty() {
311            return Ok(Vec::new());
312        }
313        let dim = self.dim;
314        let mut outputs: Vec<Vec<f32>> = Vec::with_capacity(inputs.len());
315        for input in inputs {
316            if input.len() != dim {
317                return Err(TurboQuantError::DimensionMismatch {
318                    expected: dim,
319                    got: input.len(),
320                });
321            }
322            let mut out = vec![0.0f32; dim];
323            for i in 0..dim {
324                out[i] = self
325                    .matrix
326                    .column(i)
327                    .iter()
328                    .zip(input.iter())
329                    .map(|(r, y)| r * y)
330                    .sum();
331            }
332            outputs.push(out);
333        }
334        Ok(outputs)
335    }
336}
337
338/// Generate a d×d orthogonal matrix via QR decomposition of a random Gaussian
339/// matrix. The resulting Q is Haar-distributed (uniformly random orthogonal).
340fn generate_orthogonal(dim: usize, seed: u64) -> Result<DMatrix<f32>> {
341    let mut rng = ChaCha8Rng::seed_from_u64(seed);
342    let dist = StandardNormal;
343
344    // Sample a d×d matrix with i.i.d. N(0,1) entries.
345    let data: Vec<f32> = (0..dim * dim).map(|_| dist.sample(&mut rng)).collect();
346
347    // nalgebra is column-major; DMatrix::from_vec(rows, cols, data) fills column by column.
348    let m = DMatrix::from_vec(dim, dim, data);
349
350    let qr = m.qr();
351    let q = qr.q();
352
353    // QR decomposition can return Q with det = -1. Fix the sign to ensure det = +1
354    // (a proper rotation rather than an improper one with reflection).
355    let r = qr.r();
356    let signs: Vec<f32> = (0..dim)
357        .map(|i| if r[(i, i)] >= 0.0 { 1.0 } else { -1.0 })
358        .collect();
359
360    let mut corrected = q;
361    for (j, &s) in signs.iter().enumerate() {
362        if s < 0.0 {
363            for i in 0..dim {
364                corrected[(i, j)] *= -1.0;
365            }
366        }
367    }
368
369    Ok(corrected)
370}
371
372fn check_dim(got: usize, expected: usize) -> Result<()> {
373    if got != expected {
374        return Err(TurboQuantError::DimensionMismatch { expected, got });
375    }
376    Ok(())
377}
378
/// In-place fast Walsh–Hadamard transform, scaled by `1/sqrt(n)` so the
/// transform is orthonormal. It is its own inverse.
///
/// An empty slice is left untouched.
///
/// # Panics
/// Panics if `values.len()` is non-zero and not a power of two. The butterfly
/// structure only exists for power-of-two lengths.
fn fwht_normalized(values: &mut [f32]) {
    let n = values.len();
    if n == 0 {
        return;
    }
    assert!(
        n.is_power_of_two(),
        "fwht_normalized requires a power-of-two length, got {n}"
    );

    let mut half = 1;
    while half < n {
        // Each block of 2·half elements is butterflied between its two halves.
        for block in values.chunks_exact_mut(half * 2) {
            let (lo, hi) = block.split_at_mut(half);
            for (a, b) in lo.iter_mut().zip(hi.iter_mut()) {
                let (x, y) = (*a, *b);
                *a = x + y;
                *b = x - y;
            }
        }
        half *= 2;
    }

    let scale = (n as f32).sqrt().recip();
    for value in values.iter_mut() {
        *value *= scale;
    }
}
399
400mod matrix_serde {
401    use nalgebra::DMatrix;
402    use serde::{Deserialize, Deserializer, Serialize, Serializer};
403
404    #[derive(Serialize, Deserialize)]
405    struct MatrixProxy {
406        rows: usize,
407        cols: usize,
408        data: Vec<f32>,
409    }
410
411    pub fn serialize<S: Serializer>(
412        m: &DMatrix<f32>,
413        s: S,
414    ) -> std::result::Result<S::Ok, S::Error> {
415        MatrixProxy {
416            rows: m.nrows(),
417            cols: m.ncols(),
418            data: m.as_slice().to_vec(),
419        }
420        .serialize(s)
421    }
422
423    pub fn deserialize<'de, D: Deserializer<'de>>(
424        d: D,
425    ) -> std::result::Result<DMatrix<f32>, D::Error> {
426        let p = MatrixProxy::deserialize(d)?;
427        Ok(DMatrix::from_vec(p.rows, p.cols, p.data))
428    }
429}
430
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn rotation_is_deterministic_for_same_seed() {
        let r1 = StoredRotation::new(8, 42).unwrap();
        let r2 = StoredRotation::new(8, 42).unwrap();
        assert_eq!(r1.matrix.as_slice(), r2.matrix.as_slice());
    }

    #[test]
    fn rotation_differs_across_seeds() {
        let r1 = StoredRotation::new(8, 1).unwrap();
        let r2 = StoredRotation::new(8, 2).unwrap();
        assert_ne!(r1.matrix.as_slice(), r2.matrix.as_slice());
    }

    #[test]
    fn rotation_is_orthogonal_rrt_equals_identity() {
        let r = StoredRotation::new(16, 7).unwrap();
        let m = &r.matrix;
        let product = m.transpose() * m;
        for i in 0..16 {
            for j in 0..16 {
                let expected = if i == j { 1.0f32 } else { 0.0f32 };
                let got = product[(i, j)];
                assert!(
                    (got - expected).abs() < 1e-5,
                    "RᵀR[{i},{j}] = {got}, expected {expected}"
                );
            }
        }
    }

    #[test]
    fn apply_inverse_recovers_input() {
        let r = StoredRotation::new(8, 99).unwrap();
        let x = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mut y = vec![0.0f32; 8];
        let mut recovered = vec![0.0f32; 8];

        r.apply(&x, &mut y).unwrap();
        r.apply_inverse(&y, &mut recovered).unwrap();

        for (orig, rec) in x.iter().zip(recovered.iter()) {
            assert!((orig - rec).abs() < 1e-5, "orig={orig}, recovered={rec}");
        }
    }

    #[test]
    fn rotation_preserves_inner_products() {
        // For orthogonal R: <Rx, Ry> = <x, y>
        let r = StoredRotation::new(8, 13).unwrap();
        let x = vec![1.0f32, 0.5, -1.0, 2.0, 0.1, -0.3, 1.5, 0.8];
        let y = vec![0.2f32, -1.0, 0.5, 1.0, -0.5, 0.3, 0.9, -0.7];
        let mut rx = vec![0.0f32; 8];
        let mut ry = vec![0.0f32; 8];

        r.apply(&x, &mut rx).unwrap();
        r.apply(&y, &mut ry).unwrap();

        let ip_original: f32 = x.iter().zip(y.iter()).map(|(a, b)| a * b).sum();
        let ip_rotated: f32 = rx.iter().zip(ry.iter()).map(|(a, b)| a * b).sum();

        assert!((ip_original - ip_rotated).abs() < 1e-4);
    }

    #[test]
    fn zero_dimension_is_rejected() {
        assert!(StoredRotation::new(0, 0).is_err());
    }

    #[test]
    fn serialization_roundtrip() {
        let r = StoredRotation::new(8, 55).unwrap();
        let json = serde_json::to_string(&r).unwrap();
        let restored: StoredRotation = serde_json::from_str(&json).unwrap();
        assert_eq!(r.matrix.as_slice(), restored.matrix.as_slice());
    }

    #[test]
    fn fast_hadamard_apply_inverse_recovers_input() {
        let r = FastHadamardRotation::new(8, 99).unwrap();
        let x = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mut y = vec![0.0f32; 8];
        let mut recovered = vec![0.0f32; 8];

        r.apply(&x, &mut y).unwrap();
        r.apply_inverse(&y, &mut recovered).unwrap();

        for (orig, rec) in x.iter().zip(recovered.iter()) {
            assert!((orig - rec).abs() < 1e-5, "orig={orig}, recovered={rec}");
        }
    }

    #[test]
    fn fast_hadamard_rejects_invalid_dimensions() {
        assert!(FastHadamardRotation::new(0, 1).is_err());
        assert!(FastHadamardRotation::new(12, 1).is_err());
    }

    #[test]
    fn auto_backend_picks_hadamard_only_for_powers_of_two() {
        let fast = RotationBackend::new(16, 3, RotationKind::Auto).unwrap();
        assert_eq!(fast.kind(), RotationKind::FastHadamard);

        let stored = RotationBackend::new(12, 3, RotationKind::Auto).unwrap();
        assert_eq!(stored.kind(), RotationKind::StoredQr);

        assert!(RotationBackend::new(12, 3, RotationKind::FastHadamard).is_err());
    }

    #[test]
    fn batch_inverse_matches_single_inverse() {
        for kind in [RotationKind::FastHadamard, RotationKind::StoredQr] {
            let backend = RotationBackend::new(8, 21, kind).unwrap();
            let a = vec![1.0f32, -2.0, 0.5, 3.0, 0.0, 1.5, -0.25, 2.0];
            let b = vec![0.3f32, 0.1, -0.7, 0.9, 1.1, -1.3, 0.2, 0.4];
            let inputs = [a.as_slice(), b.as_slice()];

            let batch = backend.apply_inverse_batch(&inputs).unwrap();
            assert_eq!(batch.len(), inputs.len());

            for (input, got) in inputs.iter().zip(batch.iter()) {
                let mut expected = vec![0.0f32; 8];
                backend.apply_inverse(input, &mut expected).unwrap();
                assert_eq!(got, &expected);
            }
        }
    }

    #[test]
    fn batch_inverse_rejects_wrong_length() {
        let backend = RotationBackend::new(8, 5, RotationKind::StoredQr).unwrap();
        let short = [1.0f32, 2.0];
        assert!(backend.apply_inverse_batch(&[&short[..]]).is_err());
    }
}