Skip to main content

tinyquant_core/codec/
rotation_matrix.rs

//! Canonical orthogonal rotation matrix (`ChaCha` → faer QR → sign correction).
//!
//! Pipeline:
//!
//! 1. Fill a `dim * dim` row-major `f64` buffer from [`ChaChaGaussianStream`].
//! 2. Load the buffer into a faer `Mat<f64>` and compute `A = Q R` via
//!    [`faer::Mat::qr`].
//! 3. Apply the Haar-measure sign correction `Q[:, j] *= sign(R[j, j])`
//!    (with the convention `sign(0) = 1`) so that the resulting orthogonal
//!    matrix is uniquely determined by the RNG stream (mirrors the Python
//!    reference).
//! 4. Store the corrected `Q` in row-major order inside an `Arc<[f64]>`.
//!
//! Numerical contract:
//!
//! * `apply_into` promotes each `f32` input to `f64`, accumulates the
//!   matrix-vector product in `f64`, and casts the result back to `f32`.
//!   This matches `NumPy`'s implicit promotion for `float64 @ float32` and
//!   is required to meet the round-trip parity target of `< 1e-5`.
//! * `apply_inverse_into` multiplies by the stored matrix's transpose
//!   (valid because the matrix is orthogonal, so its transpose is its
//!   inverse), again with an `f64` accumulator.
//!
//! See `docs/design/rust/numerical-semantics.md` §R1 for the full recipe
//! and rationale.
24
25use alloc::sync::Arc;
26use alloc::vec;
27
28use faer::Mat;
29use libm::fabs;
30
31use crate::codec::codec_config::{CodecConfig, MAX_DIMENSION};
32use crate::codec::gaussian::ChaChaGaussianStream;
33use crate::errors::CodecError;
34
/// Orthogonal square matrix for vector preconditioning, generated
/// deterministically from a seed.
///
/// Entries are held in an `Arc<[f64]>`, so cloning is O(1): it only bumps
/// a reference count. Copies can therefore be handed to other tasks
/// without reallocating the `dim * dim` buffer.
#[derive(Debug, Clone)]
pub struct RotationMatrix {
    /// Row-major `dimension * dimension` entries, shared between clones.
    matrix: Arc<[f64]>,
    /// Seed the matrix was generated from.
    seed: u64,
    /// Row (and column) count of the square matrix.
    dimension: u32,
}
45
46impl RotationMatrix {
47    /// Build a rotation matrix for the given [`CodecConfig`].
48    #[inline]
49    pub fn from_config(config: &CodecConfig) -> Self {
50        Self::build(config.seed(), config.dimension())
51    }
52
53    /// Build a rotation matrix for the `(seed, dimension)` pair.
54    ///
55    /// # Panics
56    ///
57    /// Panics if `dimension == 0` or `dimension > MAX_DIMENSION`. In the
58    /// normal flow, dimensions reach this function only via a validated
59    /// [`CodecConfig`] or the `from_seed_and_dim` `PyO3` wrapper, both of
60    /// which return errors instead of panicking. The asserts here are
61    /// defence-in-depth against a future caller bypassing those gates.
62    #[allow(clippy::indexing_slicing)] // bounds are statically derived from `dim`
63    pub fn build(seed: u64, dimension: u32) -> Self {
64        assert!(
65            dimension > 0,
66            "RotationMatrix::build requires dimension > 0"
67        );
68        assert!(
69            dimension <= MAX_DIMENSION,
70            "RotationMatrix::build requires dimension <= {MAX_DIMENSION}, got {dimension}",
71        );
72        let dim = dimension as usize;
73
74        // Step 1: fill the dim*dim row-major buffer from the ChaCha stream.
75        let mut data = vec![0.0f64; dim * dim];
76        let mut stream = ChaChaGaussianStream::new(seed);
77        for slot in &mut data {
78            *slot = stream.next_f64();
79        }
80
81        // Step 2: load into faer (column-major internal storage) and QR.
82        // We pass a closure that reads from our row-major buffer.
83        //
84        // SIMD divergence (R19) is handled by the AVX2 feature cap in
85        // rust/.cargo/config.toml, which empirically produces a bit-exact
86        // QR output across both x86_64 (AVX2) and aarch64 (NEON) runners
87        // — see commit e04ce5c. A `Parallelism::None` guard around `qr()`
88        // was considered as additional defence-in-depth but rejected
89        // because faer 0.19's `set_global_parallelism` mutates a
90        // process-wide atomic with no per-call alternative, and forcing
91        // serial reduction changes the output bit pattern on multi-core
92        // Linux runners — invalidating the frozen `seed_42_dim_*` fixtures
93        // that the cross-architecture parity tests rely on.
94        let a = Mat::<f64>::from_fn(dim, dim, |i, j| data[i * dim + j]);
95        let qr = a.qr();
96        let q = qr.compute_q();
97        let r = qr.compute_r();
98
99        // Step 3: Haar sign correction — multiply column j of Q by the
100        // sign of R[j, j]. The convention `sign(0) = 1` matches numpy
101        // when diag elements collide at exactly zero (rare for ChaCha).
102        let mut row_major = vec![0.0f64; dim * dim];
103        for j in 0..dim {
104            let diag = r[(j, j)];
105            let sign = if diag >= 0.0 { 1.0 } else { -1.0 };
106            for i in 0..dim {
107                row_major[i * dim + j] = q[(i, j)] * sign;
108            }
109        }
110
111        Self {
112            matrix: Arc::from(row_major.into_boxed_slice()),
113            seed,
114            dimension,
115        }
116    }
117
118    /// Borrow the row-major `dim * dim` matrix buffer.
119    #[inline]
120    pub fn matrix(&self) -> &[f64] {
121        &self.matrix
122    }
123
124    /// The row/column count of the square matrix.
125    #[inline]
126    pub const fn dimension(&self) -> u32 {
127        self.dimension
128    }
129
130    /// The seed that generated this matrix.
131    #[inline]
132    pub const fn seed(&self) -> u64 {
133        self.seed
134    }
135
136    /// Rotate `input` into `output`: `output = matrix @ input`.
137    ///
138    /// # Errors
139    ///
140    /// Returns [`CodecError::LengthMismatch`] if either slice length
141    /// differs from [`Self::dimension`].
142    pub fn apply_into(&self, input: &[f32], output: &mut [f32]) -> Result<(), CodecError> {
143        let dim = self.dimension as usize;
144        if input.len() != dim || output.len() != dim {
145            return Err(CodecError::LengthMismatch {
146                left: input.len(),
147                right: output.len(),
148            });
149        }
150        for (row, out_slot) in self.matrix.chunks_exact(dim).zip(output.iter_mut()) {
151            let acc: f64 = row
152                .iter()
153                .zip(input.iter())
154                .map(|(m, x)| m * f64::from(*x))
155                .sum();
156            #[allow(clippy::cast_possible_truncation)]
157            {
158                *out_slot = acc as f32;
159            }
160        }
161        Ok(())
162    }
163
164    /// Apply the inverse rotation: `output = matrix^T @ input`.
165    ///
166    /// Valid because the matrix is orthogonal.
167    ///
168    /// # Errors
169    ///
170    /// Returns [`CodecError::LengthMismatch`] if either slice length
171    /// differs from [`Self::dimension`].
172    pub fn apply_inverse_into(&self, input: &[f32], output: &mut [f32]) -> Result<(), CodecError> {
173        let dim = self.dimension as usize;
174        if input.len() != dim || output.len() != dim {
175            return Err(CodecError::LengthMismatch {
176                left: input.len(),
177                right: output.len(),
178            });
179        }
180        // Accumulate `output[j] = Σ_i matrix[i, j] * input[i]` in `f64`
181        // to preserve precision, then cast to `f32` once per output.
182        let mut scratch = alloc::vec![0.0f64; dim];
183        for (row, x) in self.matrix.chunks_exact(dim).zip(input.iter()) {
184            let xf = f64::from(*x);
185            for (scratch_slot, m) in scratch.iter_mut().zip(row.iter()) {
186                *scratch_slot += m * xf;
187            }
188        }
189        for (out_slot, value) in output.iter_mut().zip(scratch.iter()) {
190            #[allow(clippy::cast_possible_truncation)]
191            {
192                *out_slot = *value as f32;
193            }
194        }
195        Ok(())
196    }
197
198    /// Return `true` if `matrix @ matrix^T` equals the identity within
199    /// `tol` on every entry.
200    ///
201    /// This is `O(n^3)`; reserve for tests and low-frequency sanity checks.
202    pub fn verify_orthogonality(&self, tol: f64) -> bool {
203        let dim = self.dimension as usize;
204        for (i, row_i) in self.matrix.chunks_exact(dim).enumerate() {
205            for (j, row_j) in self.matrix.chunks_exact(dim).enumerate() {
206                let acc: f64 = row_i.iter().zip(row_j.iter()).map(|(a, b)| a * b).sum();
207                let expected = if i == j { 1.0 } else { 0.0 };
208                if fabs(acc - expected) > tol {
209                    return false;
210                }
211            }
212        }
213        true
214    }
215}