Skip to main content

tinyquant_core/codec/
rotation_matrix.rs

//! Canonical orthogonal rotation matrix (`ChaCha` → faer QR → sign correction).
//!
//! Pipeline:
//!
//! 1. Fill a `dim * dim` row-major `f64` buffer from [`ChaChaGaussianStream`].
//! 2. Load the buffer into a faer `Mat<f64>` and compute `A = Q R` via
//!    [`faer::Mat::qr`].
//! 3. Apply the Haar-measure sign correction `Q[:, j] *= sign(R[j, j])`
//!    (an exactly-zero diagonal entry counts as positive) so that the
//!    resulting orthogonal matrix is uniquely determined by the RNG stream
//!    (mirrors the Python reference).
//! 4. Store the corrected `Q` in row-major order inside an `Arc<[f64]>`.
//!
//! Numerical contract:
//!
//! * `apply_into` promotes `f32` to `f64`, does the matmul in `f64`, and
//!   casts the result back to `f32`. This matches `NumPy`'s implicit
//!   promotion for `float64 @ float32` and is required for the round-trip
//!   parity target of `< 1e-5`.
//! * `apply_inverse_into` multiplies by the stored matrix's transpose
//!   (valid because the matrix is orthogonal), again with an `f64`
//!   accumulator.
//!
//! See `docs/design/rust/numerical-semantics.md` §R1 for the full recipe
//! and rationale.
24
25use alloc::sync::Arc;
26use alloc::vec;
27
28use faer::Mat;
29use libm::fabs;
30
31use crate::codec::codec_config::CodecConfig;
32use crate::codec::gaussian::ChaChaGaussianStream;
33use crate::errors::CodecError;
34
/// Orthogonal matrix, reproducibly derived from a seed, used for vector
/// preconditioning.
///
/// Entries live in a shared `Arc<[f64]>`: cloning a `RotationMatrix` only
/// bumps a reference count, so copies can be handed to other tasks without
/// duplicating the `dim * dim` buffer.
#[derive(Clone, Debug)]
pub struct RotationMatrix {
    /// The `dimension * dimension` entries, stored row-major.
    matrix: Arc<[f64]>,
    /// Seed handed to the Gaussian stream that produced `matrix`.
    seed: u64,
    /// Number of rows (equivalently, columns) of the square matrix.
    dimension: u32,
}
45
46impl RotationMatrix {
47    /// Build a rotation matrix for the given [`CodecConfig`].
48    #[inline]
49    pub fn from_config(config: &CodecConfig) -> Self {
50        Self::build(config.seed(), config.dimension())
51    }
52
53    /// Build a rotation matrix for the `(seed, dimension)` pair.
54    ///
55    /// # Panics
56    ///
57    /// Panics if `dimension == 0`. In the normal flow, dimensions reach
58    /// this function only via a validated [`CodecConfig`], so this cannot
59    /// be triggered by safe public APIs of the crate.
60    #[allow(clippy::indexing_slicing)] // bounds are statically derived from `dim`
61    pub fn build(seed: u64, dimension: u32) -> Self {
62        assert!(
63            dimension > 0,
64            "RotationMatrix::build requires dimension > 0"
65        );
66        let dim = dimension as usize;
67
68        // Step 1: fill the dim*dim row-major buffer from the ChaCha stream.
69        let mut data = vec![0.0f64; dim * dim];
70        let mut stream = ChaChaGaussianStream::new(seed);
71        for slot in &mut data {
72            *slot = stream.next_f64();
73        }
74
75        // Step 2: load into faer (column-major internal storage) and QR.
76        // We pass a closure that reads from our row-major buffer.
77        let a = Mat::<f64>::from_fn(dim, dim, |i, j| data[i * dim + j]);
78        let qr = a.qr();
79        let q = qr.compute_q();
80        let r = qr.compute_r();
81
82        // Step 3: Haar sign correction — multiply column j of Q by the
83        // sign of R[j, j]. The convention `sign(0) = 1` matches numpy
84        // when diag elements collide at exactly zero (rare for ChaCha).
85        let mut row_major = vec![0.0f64; dim * dim];
86        for j in 0..dim {
87            let diag = r[(j, j)];
88            let sign = if diag >= 0.0 { 1.0 } else { -1.0 };
89            for i in 0..dim {
90                row_major[i * dim + j] = q[(i, j)] * sign;
91            }
92        }
93
94        Self {
95            matrix: Arc::from(row_major.into_boxed_slice()),
96            seed,
97            dimension,
98        }
99    }
100
101    /// Borrow the row-major `dim * dim` matrix buffer.
102    #[inline]
103    pub fn matrix(&self) -> &[f64] {
104        &self.matrix
105    }
106
107    /// The row/column count of the square matrix.
108    #[inline]
109    pub const fn dimension(&self) -> u32 {
110        self.dimension
111    }
112
113    /// The seed that generated this matrix.
114    #[inline]
115    pub const fn seed(&self) -> u64 {
116        self.seed
117    }
118
119    /// Rotate `input` into `output`: `output = matrix @ input`.
120    ///
121    /// # Errors
122    ///
123    /// Returns [`CodecError::LengthMismatch`] if either slice length
124    /// differs from [`Self::dimension`].
125    pub fn apply_into(&self, input: &[f32], output: &mut [f32]) -> Result<(), CodecError> {
126        let dim = self.dimension as usize;
127        if input.len() != dim || output.len() != dim {
128            return Err(CodecError::LengthMismatch {
129                left: input.len(),
130                right: output.len(),
131            });
132        }
133        for (row, out_slot) in self.matrix.chunks_exact(dim).zip(output.iter_mut()) {
134            let acc: f64 = row
135                .iter()
136                .zip(input.iter())
137                .map(|(m, x)| m * f64::from(*x))
138                .sum();
139            #[allow(clippy::cast_possible_truncation)]
140            {
141                *out_slot = acc as f32;
142            }
143        }
144        Ok(())
145    }
146
147    /// Apply the inverse rotation: `output = matrix^T @ input`.
148    ///
149    /// Valid because the matrix is orthogonal.
150    ///
151    /// # Errors
152    ///
153    /// Returns [`CodecError::LengthMismatch`] if either slice length
154    /// differs from [`Self::dimension`].
155    pub fn apply_inverse_into(&self, input: &[f32], output: &mut [f32]) -> Result<(), CodecError> {
156        let dim = self.dimension as usize;
157        if input.len() != dim || output.len() != dim {
158            return Err(CodecError::LengthMismatch {
159                left: input.len(),
160                right: output.len(),
161            });
162        }
163        // Accumulate `output[j] = Σ_i matrix[i, j] * input[i]` in `f64`
164        // to preserve precision, then cast to `f32` once per output.
165        let mut scratch = alloc::vec![0.0f64; dim];
166        for (row, x) in self.matrix.chunks_exact(dim).zip(input.iter()) {
167            let xf = f64::from(*x);
168            for (scratch_slot, m) in scratch.iter_mut().zip(row.iter()) {
169                *scratch_slot += m * xf;
170            }
171        }
172        for (out_slot, value) in output.iter_mut().zip(scratch.iter()) {
173            #[allow(clippy::cast_possible_truncation)]
174            {
175                *out_slot = *value as f32;
176            }
177        }
178        Ok(())
179    }
180
181    /// Return `true` if `matrix @ matrix^T` equals the identity within
182    /// `tol` on every entry.
183    ///
184    /// This is `O(n^3)`; reserve for tests and low-frequency sanity checks.
185    pub fn verify_orthogonality(&self, tol: f64) -> bool {
186        let dim = self.dimension as usize;
187        for (i, row_i) in self.matrix.chunks_exact(dim).enumerate() {
188            for (j, row_j) in self.matrix.chunks_exact(dim).enumerate() {
189                let acc: f64 = row_i.iter().zip(row_j.iter()).map(|(a, b)| a * b).sum();
190                let expected = if i == j { 1.0 } else { 0.0 };
191                if fabs(acc - expected) > tol {
192                    return false;
193                }
194            }
195        }
196        true
197    }
198}