Skip to main content

turboquant/
rotation.rs

//! Walsh-Hadamard Transform and random rotation.
//!
//! The rotation maps arbitrary input vectors toward a known distribution by applying
//! a randomized orthogonal transform (a seeded sign flip combined with a normalized
//! Walsh-Hadamard Transform), which makes quantization "data-oblivious".
5
6use crate::error::{require, values_match, Result, TurboQuantError};
7use crate::packed::is_valid_dim;
8
// ---------------------------------------------------------------------------
// Walsh-Hadamard Transform (pure Operation)
// ---------------------------------------------------------------------------

/// Fast Walsh-Hadamard Transform, in-place, O(d log d).
///
/// The transform is normalized by `1 / sqrt(n)` (with `n = data.len()`), which
/// makes it orthonormal and self-inverse: applying it twice returns the
/// original vector up to floating-point rounding.
///
/// An empty slice is accepted and left unchanged.
///
/// # Panics
///
/// Panics if `data.len()` is non-zero and not a power of two. The check runs
/// before any element is modified, so `data` is never left half-transformed.
/// Callers handling lengths that are not known to be valid should check them
/// first with [`is_valid_dim`](crate::packed::is_valid_dim), or go through
/// [`rotate`], which validates and returns an error instead of panicking.
pub fn wht_inplace(data: &mut [f32]) {
    let n = data.len();
    assert!(
        n == 0 || n.is_power_of_two(),
        "wht_inplace: length {} is not a power of two",
        n
    );

    // Butterfly stages: at span `h`, every block of `2h` elements is split into
    // halves `lo` / `hi`, and each pair (x, y) = (lo[k], hi[k]) becomes
    // (x + y, x - y). Because `n` is a power of two, `2h` always divides `n`.
    let mut h = 1;
    while h < n {
        for block in data.chunks_exact_mut(2 * h) {
            let (lo, hi) = block.split_at_mut(h);
            for (x, y) in lo.iter_mut().zip(hi.iter_mut()) {
                let (a, b) = (*x, *y);
                *x = a + b;
                *y = a - b;
            }
        }
        h *= 2;
    }

    // Orthonormal scaling; for an empty slice this loop does nothing.
    let norm = 1.0 / (n as f32).sqrt();
    for v in data.iter_mut() {
        *v *= norm;
    }
}
41
// ---------------------------------------------------------------------------
// Sign-pattern generation (pure Operation)
// ---------------------------------------------------------------------------

/// `floor(2^64 / phi)` (phi = golden ratio): the Weyl-sequence increment used by SplitMix64.
const GOLDEN_RATIO: u64 = 0x9E37_79B9_7F4A_7C15;

/// Hashes `(seed, index)` deterministically into a well-mixed 64-bit value.
///
/// Computes the SplitMix64 output for the Weyl state
/// `seed + (index + 1) * GOLDEN_RATIO` (all arithmetic wrapping), i.e. the
/// `index`-th (0-based) output of a SplitMix64 generator seeded with `seed`.
///
/// A bare golden-ratio multiply (`(seed + index) * GOLDEN_RATIO`) is not
/// enough here. Because `GOLDEN_RATIO` is odd, the product's least-significant
/// bit equals the parity of `seed + index`. Any caller that reads the low bit
/// (as `generate_sign_pattern` does) would then get a strictly alternating
/// `+1, -1, +1, ...` pattern for every seed. The xor-shift/multiply finalizer
/// below spreads the input into every output bit, including the lowest.
///
/// NOTE: this yields different values than the earlier multiply-only hash,
/// so sign patterns derived from a given seed differ from older versions.
fn golden_ratio_hash(seed: u64, index: usize) -> u64 {
    // SplitMix64 finalizer multipliers (Steele, Lea & Flood; Vigna's reference code).
    const MIX_MUL_1: u64 = 0xBF58_476D_1CE4_E5B9;
    const MIX_MUL_2: u64 = 0x94D0_49BB_1331_11EB;

    let step = (index as u64).wrapping_add(1);
    let mut z = seed.wrapping_add(step.wrapping_mul(GOLDEN_RATIO));
    z = (z ^ (z >> 30)).wrapping_mul(MIX_MUL_1);
    z = (z ^ (z >> 27)).wrapping_mul(MIX_MUL_2);
    z ^ (z >> 31)
}
54
55/// Generates a deterministic sign pattern of `+1.0` / `-1.0` values.
56///
57/// The same `(dim, seed)` pair always produces the identical pattern.
58/// Each element is `+1.0` when the hash has an even least-significant bit,
59/// and `-1.0` otherwise.
60pub fn generate_sign_pattern(dim: usize, seed: u64) -> Vec<f32> {
61    (0..dim)
62        .map(|i| {
63            if golden_ratio_hash(seed, i) & 1 == 0 {
64                1.0_f32
65            } else {
66                -1.0_f32
67            }
68        })
69        .collect()
70}
71
// ---------------------------------------------------------------------------
// Element-wise sign flip (pure Operation)
// ---------------------------------------------------------------------------

/// Multiplies each element of `data` by the corresponding element in `signs`.
///
/// # Precondition
///
/// `data.len() == signs.len()` -- the caller is responsible for ensuring this
/// (`rotate` does so through `validate_rotation_inputs`). A mismatch is a bug:
/// debug builds panic on it. Release builds only process the overlapping prefix,
/// because the loop stops at the shorter slice.
fn apply_sign_flip(data: &mut [f32], signs: &[f32]) {
    debug_assert_eq!(
        data.len(),
        signs.len(),
        "apply_sign_flip: data and signs must have the same length"
    );
    for (v, &s) in data.iter_mut().zip(signs) {
        *v *= s;
    }
}
86
87// ---------------------------------------------------------------------------
88// Argument validation helpers (pure Operation)
89// ---------------------------------------------------------------------------
90
91/// Validates rotation inputs and returns `Ok(())` or the appropriate error.
92///
93/// Pure Integration: only calls `require`, `is_valid_dim` (from `packed`),
94/// and `values_match` (from `error`).
95fn validate_rotation_inputs(data_len: usize, sign_len: usize) -> Result<()> {
96    require(
97        is_valid_dim(data_len),
98        TurboQuantError::InvalidDimension(data_len),
99    )?;
100    require(
101        values_match(data_len, sign_len),
102        TurboQuantError::DimensionMismatch {
103            expected: data_len,
104            actual: sign_len,
105        },
106    )
107}
108
// ---------------------------------------------------------------------------
// Rotation order (used to DRY forward / inverse rotation)
// ---------------------------------------------------------------------------

/// Selects the order of operations in the rotation transform.
///
/// Write `D` for the diagonal sign-flip matrix and `H` for the normalized WHT.
/// `Forward` computes `H·D·x` and `Inverse` computes `D·H·x`. `H` is
/// self-inverse, and `D` is too when the sign pattern contains only `±1`, so
/// each order undoes the other.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RotationOrder {
    /// Forward: sign flip first, then WHT.
    Forward,
    /// Inverse: WHT first, then sign flip.
    Inverse,
}
120
121/// Applies a rotation transform to `data` using the given `sign_pattern`.
122///
123/// - [`RotationOrder::Forward`]: element-wise sign flip, then WHT.
124/// - [`RotationOrder::Inverse`]: WHT, then element-wise sign flip.
125///
126/// The transform is orthogonal and preserves the L2 norm of the input.
127/// Forward followed by Inverse (or vice versa) recovers the original vector.
128///
129/// # Errors
130///
131/// Returns [`TurboQuantError::InvalidDimension`] if `data.len()` is not a
132/// power of two, or [`TurboQuantError::DimensionMismatch`] if `data` and
133/// `sign_pattern` differ in length.
134///
135/// Integration: validates via `validate_rotation_inputs`, then applies the
136/// two steps in the order determined by `order`.
137pub fn rotate(data: &mut [f32], sign_pattern: &[f32], order: RotationOrder) -> Result<()> {
138    validate_rotation_inputs(data.len(), sign_pattern.len())?;
139    match order {
140        RotationOrder::Forward => {
141            apply_sign_flip(data, sign_pattern);
142            wht_inplace(data);
143        }
144        RotationOrder::Inverse => {
145            wht_inplace(data);
146            apply_sign_flip(data, sign_pattern);
147        }
148    }
149    Ok(())
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    /// Small dimension used for rotation round-trip and acceptance tests.
    const TEST_SMALL_DIM: usize = 4;
    /// Seed used for deterministic sign-pattern generation in rotation tests.
    const TEST_SEED: u64 = 42;
    /// Dimension used for sign-pattern generation tests (must be a power of two).
    const TEST_SIGN_PATTERN_DIM: usize = 128;
    /// Seed used for deterministic sign-pattern generation tests.
    const TEST_SIGN_PATTERN_SEED: u64 = 99;
    /// Index used in golden_ratio_hash determinism test.
    const TEST_HASH_INDEX: usize = 7;
    /// Absolute tolerance for floating-point round-trip comparisons.
    const TEST_TOLERANCE: f32 = 1e-5;

    /// Asserts two slices have equal length and agree element-wise within `TEST_TOLERANCE`.
    fn assert_close(actual: &[f32], expected: &[f32]) {
        assert_eq!(actual.len(), expected.len());
        for (a, b) in actual.iter().zip(expected.iter()) {
            assert!((a - b).abs() < TEST_TOLERANCE, "{} != {}", a, b);
        }
    }

    // -- is_valid_dim (from packed.rs) ---------------------------------------

    #[test]
    fn is_valid_dim_accepts_powers_of_two() {
        assert!(is_valid_dim(1));
        assert!(is_valid_dim(2));
        assert!(is_valid_dim(64));
        assert!(is_valid_dim(256));
    }

    #[test]
    fn is_valid_dim_rejects_invalid() {
        assert!(!is_valid_dim(0));
        assert!(!is_valid_dim(3));
        assert!(!is_valid_dim(100));
    }

    // -- golden_ratio_hash ---------------------------------------------------

    #[test]
    fn golden_ratio_hash_is_deterministic() {
        let a = golden_ratio_hash(TEST_SEED, TEST_HASH_INDEX);
        let b = golden_ratio_hash(TEST_SEED, TEST_HASH_INDEX);
        assert_eq!(a, b);
    }

    // -- generate_sign_pattern -----------------------------------------------

    #[test]
    fn sign_pattern_elements_are_plus_or_minus_one() {
        let pattern = generate_sign_pattern(TEST_SIGN_PATTERN_DIM, TEST_SIGN_PATTERN_SEED);
        assert_eq!(pattern.len(), TEST_SIGN_PATTERN_DIM);
        assert_sign_values_valid(&pattern);
    }

    #[test]
    fn sign_pattern_is_deterministic() {
        let a = generate_sign_pattern(TEST_SIGN_PATTERN_DIM, TEST_SIGN_PATTERN_SEED);
        let b = generate_sign_pattern(TEST_SIGN_PATTERN_DIM, TEST_SIGN_PATTERN_SEED);
        assert_eq!(a, b);
    }

    /// Pure assertion helper: checks every element is +1.0 or -1.0.
    ///
    /// Separated from generation so the test is not mixing calls with loop logic.
    fn assert_sign_values_valid(pattern: &[f32]) {
        for &v in pattern {
            assert!(v == 1.0 || v == -1.0, "unexpected sign value {}", v);
        }
    }

    // -- validate_rotation_inputs (via rotate) -------------------------------

    #[test]
    fn rotate_accepts_matching_pow2_dims() {
        let mut data = vec![1.0, 2.0, 3.0, 4.0];
        let signs = generate_sign_pattern(TEST_SMALL_DIM, TEST_SEED);
        assert!(rotate(&mut data, &signs, RotationOrder::Forward).is_ok());
    }

    #[test]
    fn rotate_rejects_non_pow2() {
        let mut data = vec![1.0, 2.0, 3.0];
        let signs = vec![1.0, -1.0, 1.0];
        let result = rotate(&mut data, &signs, RotationOrder::Forward);
        assert!(matches!(result, Err(TurboQuantError::InvalidDimension(3))));
        // Validation runs before any mutation.
        assert_eq!(data, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn rotate_rejects_mismatched_lengths() {
        let mut data = vec![1.0, 2.0, 3.0, 4.0];
        let signs = vec![1.0, -1.0];
        let result = rotate(&mut data, &signs, RotationOrder::Forward);
        assert!(matches!(
            result,
            Err(TurboQuantError::DimensionMismatch {
                expected: 4,
                actual: 2
            })
        ));
        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0]);
    }

    // -- wht_inplace ---------------------------------------------------------

    #[test]
    fn wht_inplace_known_vector() {
        // The unnormalized WHT of [1, 1, 1, 1] is [4, 0, 0, 0]; scaling by
        // 1/sqrt(4) = 0.5 gives [2, 0, 0, 0].
        let mut data = vec![1.0, 1.0, 1.0, 1.0];
        wht_inplace(&mut data);
        assert!((data[0] - 2.0).abs() < 1e-6);
        assert!((data[1] - 0.0).abs() < 1e-6);
        assert!((data[2] - 0.0).abs() < 1e-6);
        assert!((data[3] - 0.0).abs() < 1e-6);
    }

    #[test]
    fn wht_inplace_self_inverse() {
        let original = vec![1.0, 2.0, 3.0, 4.0];
        let mut data = original.clone();
        wht_inplace(&mut data);
        wht_inplace(&mut data);
        assert_close(&data, &original);
    }

    #[test]
    #[should_panic]
    fn wht_inplace_panics_on_non_pow2_length() {
        let mut data = vec![1.0_f32, 2.0, 3.0];
        wht_inplace(&mut data);
    }

    // -- apply_sign_flip -----------------------------------------------------

    #[test]
    fn apply_sign_flip_basic() {
        let mut data = vec![2.0, -3.0, 4.0, -5.0];
        let signs = vec![1.0, -1.0, -1.0, 1.0];
        apply_sign_flip(&mut data, &signs);
        assert_eq!(data, vec![2.0, 3.0, -4.0, -5.0]);
    }

    // -- values_match (from error.rs) ----------------------------------------

    #[test]
    fn values_match_equal() {
        assert!(values_match(4, 4));
        assert!(values_match(128, 128));
    }

    #[test]
    fn values_match_unequal() {
        assert!(!values_match(4, 8));
        assert!(!values_match(0, 1));
    }

    // -- roundtrip rotation --------------------------------------------------

    #[test]
    fn rotate_inverse_rotate_roundtrip() {
        let original = vec![1.0, 2.0, 3.0, 4.0];
        let mut data = original.clone();
        let signs = generate_sign_pattern(TEST_SMALL_DIM, TEST_SEED);
        rotate(&mut data, &signs, RotationOrder::Forward).unwrap();
        rotate(&mut data, &signs, RotationOrder::Inverse).unwrap();
        assert_close(&data, &original);
    }

    #[test]
    fn inverse_then_forward_roundtrip() {
        let original = vec![1.0_f32, -2.0, 0.5, 4.0];
        let mut data = original.clone();
        let signs = generate_sign_pattern(TEST_SMALL_DIM, TEST_SEED);
        rotate(&mut data, &signs, RotationOrder::Inverse).unwrap();
        rotate(&mut data, &signs, RotationOrder::Forward).unwrap();
        assert_close(&data, &original);
    }

    #[test]
    fn rotate_preserves_l2_norm() {
        let original = vec![1.0_f32, 2.0, 3.0, 4.0];
        let mut data = original.clone();
        let signs = generate_sign_pattern(TEST_SMALL_DIM, TEST_SEED);
        rotate(&mut data, &signs, RotationOrder::Forward).unwrap();
        let norm_sq = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>();
        assert!((norm_sq(&data) - norm_sq(&original)).abs() < 1e-4);
    }
}