Skip to main content

sochdb_vector/
rotation.rs

1//! Walsh-Hadamard rotation for embedding preprocessing.
2//!
3//! Applies a fast, structured rotation: x' = H(D ⊙ x)
4//! where D is random ±1 diagonal and H is a Hadamard-like transform.
5
6use rand::Rng;
7use rand::SeedableRng;
8use rand_xoshiro::Xoshiro256PlusPlus;
9
/// Randomized Walsh-Hadamard rotator for embedding preprocessing.
///
/// Stores the sign pattern `D` and the sizes needed to evaluate
/// `x' = H(D ⊙ x)` on zero-padded input vectors.
#[derive(Debug, Clone, PartialEq)]
pub struct Rotator {
    /// Original dimension; rotated vectors are truncated to this length.
    dim: u32,
    /// Padded dimension (smallest power of two `>= dim`) the transform runs on.
    padded_dim: u32,
    /// Diagonal of `D`: one `+1.0` / `-1.0` entry per padded coordinate.
    signs: Vec<f32>,
}
19
20impl Rotator {
21    /// Seed for deterministic rotations
22    const ROTATE_SEED: u64 = 0xDEAD_BEEF_CAFE_1234;
23
24    /// Create a new rotator for the given dimension
25    pub fn new(dim: u32) -> Self {
26        let padded_dim = Self::next_power_of_two(dim);
27
28        // Generate random signs
29        let mut rng = Xoshiro256PlusPlus::seed_from_u64(Self::ROTATE_SEED);
30        let signs: Vec<f32> = (0..padded_dim)
31            .map(|_| if rng.r#gen::<bool>() { 1.0 } else { -1.0 })
32            .collect();
33
34        Self {
35            dim,
36            padded_dim,
37            signs,
38        }
39    }
40
41    /// Find next power of two >= n
42    fn next_power_of_two(n: u32) -> u32 {
43        let mut p = 1u32;
44        while p < n {
45            p *= 2;
46        }
47        p
48    }
49
50    /// Apply rotation: x' = H(D ⊙ x)
51    pub fn rotate(&self, x: &[f32]) -> Vec<f32> {
52        assert!(x.len() <= self.padded_dim as usize);
53
54        // Pad with zeros if needed
55        let mut v = vec![0.0f32; self.padded_dim as usize];
56        for (i, &val) in x.iter().enumerate() {
57            v[i] = val * self.signs[i];
58        }
59
60        // Apply Walsh-Hadamard transform in-place
61        self.hadamard_transform(&mut v);
62
63        // Return only the original dimensions
64        v.truncate(self.dim as usize);
65
66        // Normalize
67        let norm_factor = 1.0 / (self.padded_dim as f32).sqrt();
68        for val in &mut v {
69            *val *= norm_factor;
70        }
71
72        v
73    }
74
75    /// Apply inverse rotation
76    pub fn rotate_inverse(&self, x: &[f32]) -> Vec<f32> {
77        // Hadamard is its own inverse (up to scaling)
78        // For inverse: x = D ⊙ H(x') * scale
79
80        let mut v = vec![0.0f32; self.padded_dim as usize];
81        for (i, &val) in x.iter().enumerate() {
82            v[i] = val;
83        }
84
85        // Apply Hadamard
86        self.hadamard_transform(&mut v);
87
88        // Apply inverse signs
89        for i in 0..self.dim as usize {
90            v[i] *= self.signs[i];
91        }
92
93        // Normalize
94        let norm_factor = 1.0 / (self.padded_dim as f32).sqrt();
95        for val in &mut v {
96            *val *= norm_factor;
97        }
98
99        v.truncate(self.dim as usize);
100        v
101    }
102
103    /// Fast Walsh-Hadamard transform in-place
104    fn hadamard_transform(&self, v: &mut [f32]) {
105        let n = v.len();
106        assert!(n.is_power_of_two());
107
108        let mut h = 1;
109        while h < n {
110            for i in (0..n).step_by(h * 2) {
111                for j in i..(i + h) {
112                    let x = v[j];
113                    let y = v[j + h];
114                    v[j] = x + y;
115                    v[j + h] = x - y;
116                }
117            }
118            h *= 2;
119        }
120    }
121
122    /// Get original dimension
123    pub fn dim(&self) -> u32 {
124        self.dim
125    }
126
127    /// Get padded dimension
128    pub fn padded_dim(&self) -> u32 {
129        self.padded_dim
130    }
131}
132
/// Block-Hadamard rotator for non-power-of-two dimensions.
///
/// Splits a vector into consecutive blocks of size `2^k` and applies a signed
/// Hadamard transform to each block independently.
#[derive(Debug, Clone, PartialEq)]
pub struct BlockRotator {
    /// Length of the vectors produced by the rotator.
    dim: u32,
    /// Number of coordinates per block (a power of two).
    block_size: u32,
    /// Number of blocks needed to cover `dim` coordinates (the last may be partial).
    num_blocks: u32,
    /// Random `+1.0` / `-1.0` signs, one per coordinate of every (padded) block.
    signs: Vec<f32>,
}
141
142impl BlockRotator {
143    /// Seed for deterministic block rotations
144    const BLOCK_ROT_SEED: u64 = 0xB10C_B0A7_CAFE_5678;
145
146    /// Create a block rotator with specified block size
147    pub fn new(dim: u32, block_size: u32) -> Self {
148        assert!(block_size.is_power_of_two());
149        let num_blocks = (dim + block_size - 1) / block_size;
150
151        let mut rng = Xoshiro256PlusPlus::seed_from_u64(Self::BLOCK_ROT_SEED);
152        let total_size = num_blocks * block_size;
153        let signs: Vec<f32> = (0..total_size)
154            .map(|_| if rng.r#gen::<bool>() { 1.0 } else { -1.0 })
155            .collect();
156
157        Self {
158            dim,
159            block_size,
160            num_blocks,
161            signs,
162        }
163    }
164
165    /// Apply block rotation
166    pub fn rotate(&self, x: &[f32]) -> Vec<f32> {
167        let mut result = vec![0.0f32; self.dim as usize];
168
169        for block_idx in 0..self.num_blocks as usize {
170            let start = block_idx * self.block_size as usize;
171            let end = (start + self.block_size as usize).min(self.dim as usize);
172
173            // Pad block
174            let mut block = vec![0.0f32; self.block_size as usize];
175            for (i, idx) in (start..end).enumerate() {
176                if idx < x.len() {
177                    block[i] = x[idx] * self.signs[start + i];
178                }
179            }
180
181            // Apply Hadamard to block
182            Self::hadamard_transform_block(&mut block);
183
184            // Copy back
185            let norm = 1.0 / (self.block_size as f32).sqrt();
186            for (i, idx) in (start..end).enumerate() {
187                result[idx] = block[i] * norm;
188            }
189        }
190
191        result
192    }
193
194    fn hadamard_transform_block(v: &mut [f32]) {
195        let n = v.len();
196        let mut h = 1;
197        while h < n {
198            for i in (0..n).step_by(h * 2) {
199                for j in i..(i + h) {
200                    let x = v[j];
201                    let y = v[j + h];
202                    v[j] = x + y;
203                    v[j + h] = x - y;
204                }
205            }
206            h *= 2;
207        }
208    }
209}
210
#[cfg(test)]
mod tests {
    use super::*;

    /// Euclidean norm of `values`, accumulated in `f32`.
    fn l2_norm(values: &[f32]) -> f32 {
        values.iter().map(|v| v * v).sum::<f32>().sqrt()
    }

    /// Builds the ramp `(i - offset) * step` for `i` in `0..len`.
    fn ramp(len: usize, offset: f32, step: f32) -> Vec<f32> {
        (0..len).map(|i| (i as f32 - offset) * step).collect()
    }

    /// A power-of-two `Rotator` should leave the vector norm unchanged.
    #[test]
    fn test_rotator_preserves_norm() {
        let rotator = Rotator::new(64);
        let input = ramp(64, 32.0, 0.1);

        let norm_before = l2_norm(&input);
        let norm_after = l2_norm(&rotator.rotate(&input));

        assert!(
            (norm_before - norm_after).abs() < 0.01,
            "Norms differ: {} vs {}",
            norm_before,
            norm_after
        );
    }

    /// `rotate` followed by `rotate_inverse` should reproduce the input.
    #[test]
    fn test_rotation_roundtrip() {
        let rotator = Rotator::new(64);
        let original = ramp(64, 32.0, 0.1);

        let restored = rotator.rotate_inverse(&rotator.rotate(&original));

        for (a, b) in original.iter().zip(restored.iter()) {
            assert!((a - b).abs() < 0.01, "Mismatch: {} vs {}", a, b);
        }
    }

    /// A unit basis vector should spread to components of equal magnitude.
    #[test]
    fn test_hadamard_basic() {
        let rotated = Rotator::new(4).rotate(&[1.0, 0.0, 0.0, 0.0]);

        let reference = rotated[0].abs();
        assert!(rotated.iter().all(|&v| (v.abs() - reference).abs() < 0.01));
    }

    /// With `dim` a multiple of the block size, block rotation keeps the norm.
    #[test]
    fn test_block_rotator() {
        let rotator = BlockRotator::new(768, 64);
        let input = ramp(768, 384.0, 0.01);

        let norm_before = l2_norm(&input);
        let norm_after = l2_norm(&rotator.rotate(&input));

        assert!((norm_before - norm_after).abs() < 0.1);
    }
}