tinyquant_core/codec/rotation_matrix.rs
//! Canonical orthogonal rotation matrix (`ChaCha` → faer QR → sign correction).
//!
//! Pipeline:
//!
//! 1. Fill a `dim * dim` row-major `f64` buffer from [`ChaChaGaussianStream`].
//! 2. Load the buffer into a faer `Mat<f64>` and compute `A = Q R` via
//!    [`faer::Mat::qr`].
//! 3. Apply the Haar-measure sign correction `Q[:, j] *= sign(R[j, j])`
//!    so that the resulting orthogonal matrix is uniquely determined by
//!    the RNG stream (mirrors the Python reference).
//! 4. Store the corrected `Q` in row-major order inside an `Arc<[f64]>`.
//!
//! Numerical contract:
//!
//! * `apply_into` promotes `f32` to `f64`, does the matmul in `f64`, and
//!   casts the result back to `f32`. This matches `NumPy`'s implicit
//!   promotion for `float64 @ float32` and is required for the round-trip
//!   parity target of `< 1e-5`.
//! * `apply_inverse_into` uses the stored matrix's transpose (valid
//!   because it's orthogonal), again with an `f64` accumulator.
//!
//! See `docs/design/rust/numerical-semantics.md` §R1 for the full recipe
//! and rationale.

25use alloc::sync::Arc;
26use alloc::vec;
27
28use faer::Mat;
29use libm::fabs;
30
31use crate::codec::codec_config::{CodecConfig, MAX_DIMENSION};
32use crate::codec::gaussian::ChaChaGaussianStream;
33use crate::errors::CodecError;
34
/// Deterministically generated orthogonal matrix for vector preconditioning.
///
/// The inner storage is `Arc<[f64]>` so `Clone` is O(1) — callers can
/// pass copies between tasks without reallocating the `dim * dim` buffer.
#[derive(Clone, Debug)]
pub struct RotationMatrix {
    // Row-major `dim * dim` buffer holding the sign-corrected Q factor.
    matrix: Arc<[f64]>,
    // RNG seed the matrix was generated from (see `build`).
    seed: u64,
    // Row/column count of the square matrix; validated by `build`'s asserts.
    dimension: u32,
}
45
46impl RotationMatrix {
47 /// Build a rotation matrix for the given [`CodecConfig`].
48 #[inline]
49 pub fn from_config(config: &CodecConfig) -> Self {
50 Self::build(config.seed(), config.dimension())
51 }
52
53 /// Build a rotation matrix for the `(seed, dimension)` pair.
54 ///
55 /// # Panics
56 ///
57 /// Panics if `dimension == 0` or `dimension > MAX_DIMENSION`. In the
58 /// normal flow, dimensions reach this function only via a validated
59 /// [`CodecConfig`] or the `from_seed_and_dim` `PyO3` wrapper, both of
60 /// which return errors instead of panicking. The asserts here are
61 /// defence-in-depth against a future caller bypassing those gates.
62 #[allow(clippy::indexing_slicing)] // bounds are statically derived from `dim`
63 pub fn build(seed: u64, dimension: u32) -> Self {
64 assert!(
65 dimension > 0,
66 "RotationMatrix::build requires dimension > 0"
67 );
68 assert!(
69 dimension <= MAX_DIMENSION,
70 "RotationMatrix::build requires dimension <= {MAX_DIMENSION}, got {dimension}",
71 );
72 let dim = dimension as usize;
73
74 // Step 1: fill the dim*dim row-major buffer from the ChaCha stream.
75 let mut data = vec![0.0f64; dim * dim];
76 let mut stream = ChaChaGaussianStream::new(seed);
77 for slot in &mut data {
78 *slot = stream.next_f64();
79 }
80
81 // Step 2: load into faer (column-major internal storage) and QR.
82 // We pass a closure that reads from our row-major buffer.
83 //
84 // SIMD divergence (R19) is handled by the AVX2 feature cap in
85 // rust/.cargo/config.toml, which empirically produces a bit-exact
86 // QR output across both x86_64 (AVX2) and aarch64 (NEON) runners
87 // — see commit e04ce5c. A `Parallelism::None` guard around `qr()`
88 // was considered as additional defence-in-depth but rejected
89 // because faer 0.19's `set_global_parallelism` mutates a
90 // process-wide atomic with no per-call alternative, and forcing
91 // serial reduction changes the output bit pattern on multi-core
92 // Linux runners — invalidating the frozen `seed_42_dim_*` fixtures
93 // that the cross-architecture parity tests rely on.
94 let a = Mat::<f64>::from_fn(dim, dim, |i, j| data[i * dim + j]);
95 let qr = a.qr();
96 let q = qr.compute_q();
97 let r = qr.compute_r();
98
99 // Step 3: Haar sign correction — multiply column j of Q by the
100 // sign of R[j, j]. The convention `sign(0) = 1` matches numpy
101 // when diag elements collide at exactly zero (rare for ChaCha).
102 let mut row_major = vec![0.0f64; dim * dim];
103 for j in 0..dim {
104 let diag = r[(j, j)];
105 let sign = if diag >= 0.0 { 1.0 } else { -1.0 };
106 for i in 0..dim {
107 row_major[i * dim + j] = q[(i, j)] * sign;
108 }
109 }
110
111 Self {
112 matrix: Arc::from(row_major.into_boxed_slice()),
113 seed,
114 dimension,
115 }
116 }
117
118 /// Borrow the row-major `dim * dim` matrix buffer.
119 #[inline]
120 pub fn matrix(&self) -> &[f64] {
121 &self.matrix
122 }
123
    /// The row/column count of the square matrix.
    ///
    /// Always in `1..=MAX_DIMENSION`: [`Self::build`] asserts both bounds
    /// before constructing an instance.
    #[inline]
    pub const fn dimension(&self) -> u32 {
        self.dimension
    }
129
    /// The seed that generated this matrix.
    ///
    /// Together with [`Self::dimension`] this identifies the matrix
    /// contents (generation is deterministic; see the module docs for the
    /// SIMD-related caveats).
    #[inline]
    pub const fn seed(&self) -> u64 {
        self.seed
    }
135
136 /// Rotate `input` into `output`: `output = matrix @ input`.
137 ///
138 /// # Errors
139 ///
140 /// Returns [`CodecError::LengthMismatch`] if either slice length
141 /// differs from [`Self::dimension`].
142 pub fn apply_into(&self, input: &[f32], output: &mut [f32]) -> Result<(), CodecError> {
143 let dim = self.dimension as usize;
144 if input.len() != dim || output.len() != dim {
145 return Err(CodecError::LengthMismatch {
146 left: input.len(),
147 right: output.len(),
148 });
149 }
150 for (row, out_slot) in self.matrix.chunks_exact(dim).zip(output.iter_mut()) {
151 let acc: f64 = row
152 .iter()
153 .zip(input.iter())
154 .map(|(m, x)| m * f64::from(*x))
155 .sum();
156 #[allow(clippy::cast_possible_truncation)]
157 {
158 *out_slot = acc as f32;
159 }
160 }
161 Ok(())
162 }
163
164 /// Apply the inverse rotation: `output = matrix^T @ input`.
165 ///
166 /// Valid because the matrix is orthogonal.
167 ///
168 /// # Errors
169 ///
170 /// Returns [`CodecError::LengthMismatch`] if either slice length
171 /// differs from [`Self::dimension`].
172 pub fn apply_inverse_into(&self, input: &[f32], output: &mut [f32]) -> Result<(), CodecError> {
173 let dim = self.dimension as usize;
174 if input.len() != dim || output.len() != dim {
175 return Err(CodecError::LengthMismatch {
176 left: input.len(),
177 right: output.len(),
178 });
179 }
180 // Accumulate `output[j] = Σ_i matrix[i, j] * input[i]` in `f64`
181 // to preserve precision, then cast to `f32` once per output.
182 let mut scratch = alloc::vec![0.0f64; dim];
183 for (row, x) in self.matrix.chunks_exact(dim).zip(input.iter()) {
184 let xf = f64::from(*x);
185 for (scratch_slot, m) in scratch.iter_mut().zip(row.iter()) {
186 *scratch_slot += m * xf;
187 }
188 }
189 for (out_slot, value) in output.iter_mut().zip(scratch.iter()) {
190 #[allow(clippy::cast_possible_truncation)]
191 {
192 *out_slot = *value as f32;
193 }
194 }
195 Ok(())
196 }
197
198 /// Return `true` if `matrix @ matrix^T` equals the identity within
199 /// `tol` on every entry.
200 ///
201 /// This is `O(n^3)`; reserve for tests and low-frequency sanity checks.
202 pub fn verify_orthogonality(&self, tol: f64) -> bool {
203 let dim = self.dimension as usize;
204 for (i, row_i) in self.matrix.chunks_exact(dim).enumerate() {
205 for (j, row_j) in self.matrix.chunks_exact(dim).enumerate() {
206 let acc: f64 = row_i.iter().zip(row_j.iter()).map(|(a, b)| a * b).sum();
207 let expected = if i == j { 1.0 } else { 0.0 };
208 if fabs(acc - expected) > tol {
209 return false;
210 }
211 }
212 }
213 true
214 }
215}