turboquant/rotation.rs
1//! Walsh-Hadamard Transform and random rotation.
2//!
3//! The rotation transforms arbitrary input vectors into a known distribution
4//! using a randomized orthogonal transform, making quantization "data-oblivious".
5
6use crate::error::{require, values_match, Result, TurboQuantError};
7use crate::packed::is_valid_dim;
8
9// ---------------------------------------------------------------------------
10// Walsh-Hadamard Transform (pure Operation)
11// ---------------------------------------------------------------------------
12
/// Fast Walsh-Hadamard Transform, in-place, O(n log n).
///
/// Applies the (Sylvester-ordered) Hadamard matrix to `data` with the
/// iterative butterfly algorithm, then scales by `1 / sqrt(n)` where
/// `n = data.len()`. With that normalization the transform is orthogonal
/// (it preserves the L2 norm) and self-inverse: applying it twice returns the
/// original vector, up to floating-point rounding.
///
/// An empty slice is left untouched.
///
/// # Panics
///
/// Panics if `data.len()` is non-zero and not a power of two. Check the length
/// with [`is_valid_dim`](crate::packed::is_valid_dim) first, or go through
/// [`rotate`], which reports an error instead of panicking.
pub fn wht_inplace(data: &mut [f32]) {
    let n = data.len();
    if n == 0 {
        return;
    }
    // The butterfly stages only line up for power-of-two lengths; any other
    // length would run off the end of the slice part-way through a stage.
    assert!(
        n.is_power_of_two(),
        "wht_inplace: length {} is not a power of two",
        n
    );

    let mut half = 1;
    while half < n {
        // Each block of `2 * half` elements is split into two halves whose
        // entries are combined pairwise: (x, y) -> (x + y, x - y).
        for block in data.chunks_exact_mut(half * 2) {
            let (lo, hi) = block.split_at_mut(half);
            for (x, y) in lo.iter_mut().zip(hi.iter_mut()) {
                let (a, b) = (*x, *y);
                *x = a + b;
                *y = a - b;
            }
        }
        half *= 2;
    }

    let norm = 1.0 / (n as f32).sqrt();
    for v in data.iter_mut() {
        *v *= norm;
    }
}
41
42// ---------------------------------------------------------------------------
43// Sign-pattern generation (pure Operation)
44// ---------------------------------------------------------------------------
45
/// Golden-ratio constant `floor(2^64 / phi)`, used as the SplitMix64 state step.
const GOLDEN_RATIO: u64 = 0x9E37_79B9_7F4A_7C15;

/// First multiplier of the SplitMix64 output finalizer.
const MIX_MUL_1: u64 = 0xBF58_476D_1CE4_E5B9;
/// Second multiplier of the SplitMix64 output finalizer.
const MIX_MUL_2: u64 = 0x94D0_49BB_1331_11EB;

/// Hashes `(seed, index)` deterministically into a well-mixed 64-bit value.
///
/// Returns the `index`-th (0-based) output of a SplitMix64 generator seeded
/// with `seed`. The index is spread over the state space with a golden-ratio
/// (Weyl) step, and the result is passed through the SplitMix64
/// xor-shift-multiply finalizer. After that step every output bit depends on
/// all input bits, including the least-significant bit that
/// [`generate_sign_pattern`] reads.
///
/// A bare `(seed + index) * GOLDEN_RATIO` is not enough. The constant is odd,
/// so the product's least-significant bit equals that of `seed + index`, and
/// every sign pattern degenerates into a `+1, -1, +1, -1, ...` alternation
/// selected only by the parity of `seed`. Flipping signs with such a pattern
/// merely permutes the rows of the Hadamard matrix, so the rotation would not
/// be randomized at all.
///
/// NOTE: every sign pattern (and hence every rotation) is a function of this
/// hash. Vectors rotated with patterns from a different hash cannot be
/// recovered with patterns from this one.
fn golden_ratio_hash(seed: u64, index: usize) -> u64 {
    let step = (index as u64).wrapping_add(1).wrapping_mul(GOLDEN_RATIO);
    let mut z = seed.wrapping_add(step);
    z = (z ^ (z >> 30)).wrapping_mul(MIX_MUL_1);
    z = (z ^ (z >> 27)).wrapping_mul(MIX_MUL_2);
    z ^ (z >> 31)
}
54
55/// Generates a deterministic sign pattern of `+1.0` / `-1.0` values.
56///
57/// The same `(dim, seed)` pair always produces the identical pattern.
58/// Each element is `+1.0` when the hash has an even least-significant bit,
59/// and `-1.0` otherwise.
60pub fn generate_sign_pattern(dim: usize, seed: u64) -> Vec<f32> {
61 (0..dim)
62 .map(|i| {
63 if golden_ratio_hash(seed, i) & 1 == 0 {
64 1.0_f32
65 } else {
66 -1.0_f32
67 }
68 })
69 .collect()
70}
71
72// ---------------------------------------------------------------------------
73// Element-wise sign flip (pure Operation)
74// ---------------------------------------------------------------------------
75
/// Scales every entry of `data`, in place, by the entry of `signs` at the same
/// position.
///
/// # Precondition
///
/// The caller must ensure `data.len() == signs.len()`. Pairing stops at the
/// shorter slice, so any surplus entries of `data` are left unchanged rather
/// than triggering a panic.
fn apply_sign_flip(data: &mut [f32], signs: &[f32]) {
    data.iter_mut()
        .zip(signs)
        .for_each(|(value, sign)| *value *= *sign);
}
86
87// ---------------------------------------------------------------------------
// Argument validation helpers (pure Integration)
89// ---------------------------------------------------------------------------
90
91/// Validates rotation inputs and returns `Ok(())` or the appropriate error.
92///
93/// Pure Integration: only calls `require`, `is_valid_dim` (from `packed`),
94/// and `values_match` (from `error`).
95fn validate_rotation_inputs(data_len: usize, sign_len: usize) -> Result<()> {
96 require(
97 is_valid_dim(data_len),
98 TurboQuantError::InvalidDimension(data_len),
99 )?;
100 require(
101 values_match(data_len, sign_len),
102 TurboQuantError::DimensionMismatch {
103 expected: data_len,
104 actual: sign_len,
105 },
106 )
107}
108
109// ---------------------------------------------------------------------------
110// Rotation order (used to DRY forward / inverse rotation)
111// ---------------------------------------------------------------------------
112
/// Selects the order of operations in the rotation transform.
///
/// Used with the same sign pattern, the two orders undo each other; see
/// [`rotate`]. The enum is a plain fieldless tag, so it is cheap to copy and
/// compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RotationOrder {
    /// Forward: sign flip first, then WHT.
    Forward,
    /// Inverse: WHT first, then sign flip.
    Inverse,
}
120
121/// Applies a rotation transform to `data` using the given `sign_pattern`.
122///
123/// - [`RotationOrder::Forward`]: element-wise sign flip, then WHT.
124/// - [`RotationOrder::Inverse`]: WHT, then element-wise sign flip.
125///
126/// The transform is orthogonal and preserves the L2 norm of the input.
127/// Forward followed by Inverse (or vice versa) recovers the original vector.
128///
129/// # Errors
130///
131/// Returns [`TurboQuantError::InvalidDimension`] if `data.len()` is not a
132/// power of two, or [`TurboQuantError::DimensionMismatch`] if `data` and
133/// `sign_pattern` differ in length.
134///
135/// Integration: validates via `validate_rotation_inputs`, then applies the
136/// two steps in the order determined by `order`.
137pub fn rotate(data: &mut [f32], sign_pattern: &[f32], order: RotationOrder) -> Result<()> {
138 validate_rotation_inputs(data.len(), sign_pattern.len())?;
139 match order {
140 RotationOrder::Forward => {
141 apply_sign_flip(data, sign_pattern);
142 wht_inplace(data);
143 }
144 RotationOrder::Inverse => {
145 wht_inplace(data);
146 apply_sign_flip(data, sign_pattern);
147 }
148 }
149 Ok(())
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    /// Small dimension used for rotation round-trip and acceptance tests.
    const TEST_SMALL_DIM: usize = 4;
    /// Seed used for deterministic sign-pattern generation in rotation tests.
    const TEST_SEED: u64 = 42;
    /// Dimension used for sign-pattern generation tests (must be a power of two).
    const TEST_SIGN_PATTERN_DIM: usize = 128;
    /// Seed used for deterministic sign-pattern generation tests.
    const TEST_SIGN_PATTERN_SEED: u64 = 99;
    /// Index used in golden_ratio_hash determinism test.
    const TEST_HASH_INDEX: usize = 7;
    /// Absolute tolerance for values that pass through the `1 / sqrt(n)` scaling.
    const TOLERANCE: f32 = 1e-5;

    /// Pure assertion helper: checks that `actual` and `expected` have the same
    /// length and agree element-wise within [`TOLERANCE`].
    fn assert_all_close(actual: &[f32], expected: &[f32]) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (a, e) in actual.iter().zip(expected) {
            assert!((a - e).abs() < TOLERANCE, "expected {}, got {}", e, a);
        }
    }

    // -- is_valid_dim (from packed.rs) ----------------------------------------

    #[test]
    fn is_valid_dim_accepts_powers_of_two() {
        assert!(is_valid_dim(1));
        assert!(is_valid_dim(2));
        assert!(is_valid_dim(64));
        assert!(is_valid_dim(256));
    }

    #[test]
    fn is_valid_dim_rejects_invalid() {
        assert!(!is_valid_dim(0));
        assert!(!is_valid_dim(3));
        assert!(!is_valid_dim(100));
    }

    // -- golden_ratio_hash ----------------------------------------------------

    #[test]
    fn golden_ratio_hash_is_deterministic() {
        let a = golden_ratio_hash(TEST_SEED, TEST_HASH_INDEX);
        let b = golden_ratio_hash(TEST_SEED, TEST_HASH_INDEX);
        assert_eq!(a, b);
    }

    // -- generate_sign_pattern ------------------------------------------------

    #[test]
    fn sign_pattern_elements_are_plus_or_minus_one() {
        let pattern = generate_sign_pattern(TEST_SIGN_PATTERN_DIM, TEST_SIGN_PATTERN_SEED);
        assert_eq!(pattern.len(), TEST_SIGN_PATTERN_DIM);
        assert_sign_values_valid(&pattern);
    }

    #[test]
    fn sign_pattern_is_deterministic() {
        let a = generate_sign_pattern(TEST_SIGN_PATTERN_DIM, TEST_SIGN_PATTERN_SEED);
        let b = generate_sign_pattern(TEST_SIGN_PATTERN_DIM, TEST_SIGN_PATTERN_SEED);
        assert_eq!(a, b);
    }

    /// Pure assertion helper: checks every element is +1.0 or -1.0.
    ///
    /// Separated from generation so the test is not mixing calls with loop logic.
    fn assert_sign_values_valid(pattern: &[f32]) {
        for &v in pattern {
            assert!(v == 1.0 || v == -1.0, "unexpected sign value {}", v);
        }
    }

    // -- validate_rotation_inputs (exercised through rotate) ------------------

    #[test]
    fn rotate_accepts_matching_pow2_dims() {
        let mut data = [1.0_f32, 2.0, 3.0, 4.0];
        let signs = generate_sign_pattern(TEST_SMALL_DIM, TEST_SEED);
        assert!(rotate(&mut data, &signs, RotationOrder::Forward).is_ok());
    }

    #[test]
    fn rotate_rejects_non_pow2() {
        let mut data = [1.0_f32, 2.0, 3.0];
        let signs = [1.0_f32, -1.0, 1.0];
        let result = rotate(&mut data, &signs, RotationOrder::Forward);
        assert!(matches!(result, Err(TurboQuantError::InvalidDimension(3))));
    }

    #[test]
    fn rotate_rejects_mismatched_lengths() {
        let mut data = [1.0_f32, 2.0, 3.0, 4.0];
        let signs = [1.0_f32, -1.0];
        let result = rotate(&mut data, &signs, RotationOrder::Forward);
        assert!(matches!(
            result,
            Err(TurboQuantError::DimensionMismatch {
                expected: 4,
                actual: 2
            })
        ));
        // Validation runs before any mutation, so the input is untouched.
        assert_eq!(data, [1.0, 2.0, 3.0, 4.0]);
    }

    // -- wht_inplace ----------------------------------------------------------

    #[test]
    fn wht_inplace_known_vector() {
        // The unnormalized WHT of [1, 1, 1, 1] is [4, 0, 0, 0]; scaling by
        // 1/sqrt(4) = 0.5 gives [2, 0, 0, 0].
        let mut data = [1.0_f32; 4];
        wht_inplace(&mut data);
        assert_all_close(&data, &[2.0, 0.0, 0.0, 0.0]);
    }

    #[test]
    fn wht_inplace_matches_hadamard_matrix() {
        // H4 * [1, 2, 3, 4] = [10, -2, -4, 0]; scaling by 0.5 gives [5, -1, -2, 0].
        let mut data = [1.0_f32, 2.0, 3.0, 4.0];
        wht_inplace(&mut data);
        assert_all_close(&data, &[5.0, -1.0, -2.0, 0.0]);
    }

    #[test]
    fn wht_inplace_self_inverse() {
        let original = [1.0_f32, 2.0, 3.0, 4.0];
        let mut data = original;
        wht_inplace(&mut data);
        wht_inplace(&mut data);
        assert_all_close(&data, &original);
    }

    #[test]
    fn wht_inplace_empty_is_noop() {
        let mut data: Vec<f32> = Vec::new();
        wht_inplace(&mut data);
        assert!(data.is_empty());
    }

    #[test]
    #[should_panic]
    fn wht_inplace_panics_on_non_pow2_length() {
        let mut data = [1.0_f32, 2.0, 3.0];
        wht_inplace(&mut data);
    }

    // -- apply_sign_flip ------------------------------------------------------

    #[test]
    fn apply_sign_flip_basic() {
        let mut data = [2.0_f32, -3.0, 4.0, -5.0];
        let signs = [1.0_f32, -1.0, -1.0, 1.0];
        apply_sign_flip(&mut data, &signs);
        assert_eq!(data, [2.0, 3.0, -4.0, -5.0]);
    }

    // -- values_match (from error.rs) -----------------------------------------

    #[test]
    fn values_match_equal() {
        assert!(values_match(4, 4));
        assert!(values_match(128, 128));
    }

    #[test]
    fn values_match_unequal() {
        assert!(!values_match(4, 8));
        assert!(!values_match(0, 1));
    }

    // -- round-trip rotation --------------------------------------------------

    #[test]
    fn rotate_inverse_rotate_roundtrip() {
        let original = [1.0_f32, 2.0, 3.0, 4.0];
        let mut data = original;
        let signs = generate_sign_pattern(TEST_SMALL_DIM, TEST_SEED);
        rotate(&mut data, &signs, RotationOrder::Forward).unwrap();
        rotate(&mut data, &signs, RotationOrder::Inverse).unwrap();
        assert_all_close(&data, &original);
    }

    #[test]
    fn inverse_then_forward_roundtrip() {
        let original = [1.0_f32, 2.0, 3.0, 4.0];
        let mut data = original;
        let signs = generate_sign_pattern(TEST_SMALL_DIM, TEST_SEED);
        rotate(&mut data, &signs, RotationOrder::Inverse).unwrap();
        rotate(&mut data, &signs, RotationOrder::Forward).unwrap();
        assert_all_close(&data, &original);
    }

    #[test]
    fn rotate_preserves_l2_norm() {
        let original = [1.0_f32, -2.0, 3.0, 0.5, 4.0, -1.5, 2.5, 0.0];
        let mut data = original;
        let signs = generate_sign_pattern(original.len(), TEST_SEED);
        rotate(&mut data, &signs, RotationOrder::Forward).unwrap();
        let before: f32 = original.iter().map(|v| v * v).sum();
        let after: f32 = data.iter().map(|v| v * v).sum();
        assert!((before - after).abs() < 1e-4 * before);
    }
}