// tinyquant_core/codec/rotation_matrix.rs
//! Canonical orthogonal rotation matrix (`ChaCha` → faer QR → sign correction).
//!
//! Pipeline:
//!
//! 1. Fill a `dim * dim` row-major `f64` buffer from [`ChaChaGaussianStream`].
//! 2. Load the buffer into a faer `Mat<f64>` and compute `A = Q R` via
//!    [`faer::Mat::qr`].
//! 3. Apply the Haar-measure sign correction `Q[:, j] *= sign(R[j, j])`
//!    so that the resulting orthogonal matrix is uniquely determined by
//!    the RNG stream (mirrors the Python reference).
//! 4. Store the corrected `Q` in row-major order inside an `Arc<[f64]>`.
//!
//! Numerical contract:
//!
//! * `apply_into` promotes `f32` to `f64`, does the matmul in `f64`, and
//!   casts the result back to `f32`. This matches `NumPy`'s implicit
//!   promotion for `float64 @ float32` and is required for the round-trip
//!   parity target of `< 1e-5`.
//! * `apply_inverse_into` uses the stored matrix's transpose (valid
//!   because it's orthogonal), again with an `f64` accumulator.
//!
//! See `docs/design/rust/numerical-semantics.md` §R1 for the full recipe
//! and rationale.

use alloc::sync::Arc;
use alloc::vec;

use faer::Mat;
use libm::fabs;

use crate::codec::codec_config::CodecConfig;
use crate::codec::gaussian::ChaChaGaussianStream;
use crate::errors::CodecError;

/// Deterministically generated orthogonal matrix for vector preconditioning.
///
/// The inner storage is `Arc<[f64]>` so `Clone` is O(1) — callers can
/// pass copies between tasks without reallocating the `dim * dim` buffer.
#[derive(Clone, Debug)]
pub struct RotationMatrix {
    // Row-major `dim * dim` orthogonal matrix; `Arc<[f64]>` so clones
    // share the buffer instead of copying it.
    matrix: Arc<[f64]>,
    // RNG seed that produced `matrix` (via `ChaChaGaussianStream`).
    seed: u64,
    // Row/column count of the square matrix; `matrix.len() == dim * dim`.
    dimension: u32,
}
45
46impl RotationMatrix {
47 /// Build a rotation matrix for the given [`CodecConfig`].
48 #[inline]
49 pub fn from_config(config: &CodecConfig) -> Self {
50 Self::build(config.seed(), config.dimension())
51 }
52
53 /// Build a rotation matrix for the `(seed, dimension)` pair.
54 ///
55 /// # Panics
56 ///
57 /// Panics if `dimension == 0`. In the normal flow, dimensions reach
58 /// this function only via a validated [`CodecConfig`], so this cannot
59 /// be triggered by safe public APIs of the crate.
60 #[allow(clippy::indexing_slicing)] // bounds are statically derived from `dim`
61 pub fn build(seed: u64, dimension: u32) -> Self {
62 assert!(
63 dimension > 0,
64 "RotationMatrix::build requires dimension > 0"
65 );
66 let dim = dimension as usize;
67
68 // Step 1: fill the dim*dim row-major buffer from the ChaCha stream.
69 let mut data = vec![0.0f64; dim * dim];
70 let mut stream = ChaChaGaussianStream::new(seed);
71 for slot in &mut data {
72 *slot = stream.next_f64();
73 }
74
75 // Step 2: load into faer (column-major internal storage) and QR.
76 // We pass a closure that reads from our row-major buffer.
77 let a = Mat::<f64>::from_fn(dim, dim, |i, j| data[i * dim + j]);
78 let qr = a.qr();
79 let q = qr.compute_q();
80 let r = qr.compute_r();
81
82 // Step 3: Haar sign correction — multiply column j of Q by the
83 // sign of R[j, j]. The convention `sign(0) = 1` matches numpy
84 // when diag elements collide at exactly zero (rare for ChaCha).
85 let mut row_major = vec![0.0f64; dim * dim];
86 for j in 0..dim {
87 let diag = r[(j, j)];
88 let sign = if diag >= 0.0 { 1.0 } else { -1.0 };
89 for i in 0..dim {
90 row_major[i * dim + j] = q[(i, j)] * sign;
91 }
92 }
93
94 Self {
95 matrix: Arc::from(row_major.into_boxed_slice()),
96 seed,
97 dimension,
98 }
99 }
100
101 /// Borrow the row-major `dim * dim` matrix buffer.
102 #[inline]
103 pub fn matrix(&self) -> &[f64] {
104 &self.matrix
105 }
106
107 /// The row/column count of the square matrix.
108 #[inline]
109 pub const fn dimension(&self) -> u32 {
110 self.dimension
111 }
112
113 /// The seed that generated this matrix.
114 #[inline]
115 pub const fn seed(&self) -> u64 {
116 self.seed
117 }
118
119 /// Rotate `input` into `output`: `output = matrix @ input`.
120 ///
121 /// # Errors
122 ///
123 /// Returns [`CodecError::LengthMismatch`] if either slice length
124 /// differs from [`Self::dimension`].
125 pub fn apply_into(&self, input: &[f32], output: &mut [f32]) -> Result<(), CodecError> {
126 let dim = self.dimension as usize;
127 if input.len() != dim || output.len() != dim {
128 return Err(CodecError::LengthMismatch {
129 left: input.len(),
130 right: output.len(),
131 });
132 }
133 for (row, out_slot) in self.matrix.chunks_exact(dim).zip(output.iter_mut()) {
134 let acc: f64 = row
135 .iter()
136 .zip(input.iter())
137 .map(|(m, x)| m * f64::from(*x))
138 .sum();
139 #[allow(clippy::cast_possible_truncation)]
140 {
141 *out_slot = acc as f32;
142 }
143 }
144 Ok(())
145 }
146
147 /// Apply the inverse rotation: `output = matrix^T @ input`.
148 ///
149 /// Valid because the matrix is orthogonal.
150 ///
151 /// # Errors
152 ///
153 /// Returns [`CodecError::LengthMismatch`] if either slice length
154 /// differs from [`Self::dimension`].
155 pub fn apply_inverse_into(&self, input: &[f32], output: &mut [f32]) -> Result<(), CodecError> {
156 let dim = self.dimension as usize;
157 if input.len() != dim || output.len() != dim {
158 return Err(CodecError::LengthMismatch {
159 left: input.len(),
160 right: output.len(),
161 });
162 }
163 // Accumulate `output[j] = Σ_i matrix[i, j] * input[i]` in `f64`
164 // to preserve precision, then cast to `f32` once per output.
165 let mut scratch = alloc::vec![0.0f64; dim];
166 for (row, x) in self.matrix.chunks_exact(dim).zip(input.iter()) {
167 let xf = f64::from(*x);
168 for (scratch_slot, m) in scratch.iter_mut().zip(row.iter()) {
169 *scratch_slot += m * xf;
170 }
171 }
172 for (out_slot, value) in output.iter_mut().zip(scratch.iter()) {
173 #[allow(clippy::cast_possible_truncation)]
174 {
175 *out_slot = *value as f32;
176 }
177 }
178 Ok(())
179 }
180
181 /// Return `true` if `matrix @ matrix^T` equals the identity within
182 /// `tol` on every entry.
183 ///
184 /// This is `O(n^3)`; reserve for tests and low-frequency sanity checks.
185 pub fn verify_orthogonality(&self, tol: f64) -> bool {
186 let dim = self.dimension as usize;
187 for (i, row_i) in self.matrix.chunks_exact(dim).enumerate() {
188 for (j, row_j) in self.matrix.chunks_exact(dim).enumerate() {
189 let acc: f64 = row_i.iter().zip(row_j.iter()).map(|(a, b)| a * b).sum();
190 let expected = if i == j { 1.0 } else { 0.0 };
191 if fabs(acc - expected) > tol {
192 return false;
193 }
194 }
195 }
196 true
197 }
198}