// sochdb_vector/rotation.rs
use rand::Rng;
use rand::SeedableRng;
use rand_xoshiro::Xoshiro256PlusPlus;
9
/// Randomized Hadamard rotation over `dim`-dimensional vectors.
///
/// Inputs are zero-padded to `padded_dim`, multiplied element-wise by a fixed
/// vector of random ±1 signs, and passed through a normalized Walsh–Hadamard
/// transform.
#[derive(Debug, Clone)]
pub struct Rotator {
    /// Logical dimensionality of the vectors callers pass in and get back.
    dim: u32,
    /// `dim` rounded up to a power of two: the length the transform runs on.
    padded_dim: u32,
    /// One sign (`1.0` or `-1.0`) per padded coordinate.
    signs: Vec<f32>,
}
19
20impl Rotator {
21 const ROTATE_SEED: u64 = 0xDEAD_BEEF_CAFE_1234;
23
24 pub fn new(dim: u32) -> Self {
26 let padded_dim = Self::next_power_of_two(dim);
27
28 let mut rng = Xoshiro256PlusPlus::seed_from_u64(Self::ROTATE_SEED);
30 let signs: Vec<f32> = (0..padded_dim)
31 .map(|_| if rng.r#gen::<bool>() { 1.0 } else { -1.0 })
32 .collect();
33
34 Self {
35 dim,
36 padded_dim,
37 signs,
38 }
39 }
40
41 fn next_power_of_two(n: u32) -> u32 {
43 let mut p = 1u32;
44 while p < n {
45 p *= 2;
46 }
47 p
48 }
49
50 pub fn rotate(&self, x: &[f32]) -> Vec<f32> {
52 assert!(x.len() <= self.padded_dim as usize);
53
54 let mut v = vec![0.0f32; self.padded_dim as usize];
56 for (i, &val) in x.iter().enumerate() {
57 v[i] = val * self.signs[i];
58 }
59
60 self.hadamard_transform(&mut v);
62
63 v.truncate(self.dim as usize);
65
66 let norm_factor = 1.0 / (self.padded_dim as f32).sqrt();
68 for val in &mut v {
69 *val *= norm_factor;
70 }
71
72 v
73 }
74
75 pub fn rotate_inverse(&self, x: &[f32]) -> Vec<f32> {
77 let mut v = vec![0.0f32; self.padded_dim as usize];
81 for (i, &val) in x.iter().enumerate() {
82 v[i] = val;
83 }
84
85 self.hadamard_transform(&mut v);
87
88 for i in 0..self.dim as usize {
90 v[i] *= self.signs[i];
91 }
92
93 let norm_factor = 1.0 / (self.padded_dim as f32).sqrt();
95 for val in &mut v {
96 *val *= norm_factor;
97 }
98
99 v.truncate(self.dim as usize);
100 v
101 }
102
103 fn hadamard_transform(&self, v: &mut [f32]) {
105 let n = v.len();
106 assert!(n.is_power_of_two());
107
108 let mut h = 1;
109 while h < n {
110 for i in (0..n).step_by(h * 2) {
111 for j in i..(i + h) {
112 let x = v[j];
113 let y = v[j + h];
114 v[j] = x + y;
115 v[j + h] = x - y;
116 }
117 }
118 h *= 2;
119 }
120 }
121
122 pub fn dim(&self) -> u32 {
124 self.dim
125 }
126
127 pub fn padded_dim(&self) -> u32 {
129 self.padded_dim
130 }
131}
132
/// Block-diagonal randomized Hadamard rotation.
///
/// A `dim`-dimensional vector is split into consecutive blocks of `block_size`
/// coordinates. Each block is sign-flipped and Hadamard-transformed
/// independently.
#[derive(Debug, Clone)]
pub struct BlockRotator {
    /// Logical dimensionality of the vectors callers pass in and get back.
    dim: u32,
    /// Length of each block; always a power of two.
    block_size: u32,
    /// `ceil(dim / block_size)`; the last block may be partially filled.
    num_blocks: u32,
    /// One sign (`1.0` or `-1.0`) per coordinate of every (padded) block.
    signs: Vec<f32>,
}
141
142impl BlockRotator {
143 const BLOCK_ROT_SEED: u64 = 0xB10C_B0A7_CAFE_5678;
145
146 pub fn new(dim: u32, block_size: u32) -> Self {
148 assert!(block_size.is_power_of_two());
149 let num_blocks = (dim + block_size - 1) / block_size;
150
151 let mut rng = Xoshiro256PlusPlus::seed_from_u64(Self::BLOCK_ROT_SEED);
152 let total_size = num_blocks * block_size;
153 let signs: Vec<f32> = (0..total_size)
154 .map(|_| if rng.r#gen::<bool>() { 1.0 } else { -1.0 })
155 .collect();
156
157 Self {
158 dim,
159 block_size,
160 num_blocks,
161 signs,
162 }
163 }
164
165 pub fn rotate(&self, x: &[f32]) -> Vec<f32> {
167 let mut result = vec![0.0f32; self.dim as usize];
168
169 for block_idx in 0..self.num_blocks as usize {
170 let start = block_idx * self.block_size as usize;
171 let end = (start + self.block_size as usize).min(self.dim as usize);
172
173 let mut block = vec![0.0f32; self.block_size as usize];
175 for (i, idx) in (start..end).enumerate() {
176 if idx < x.len() {
177 block[i] = x[idx] * self.signs[start + i];
178 }
179 }
180
181 Self::hadamard_transform_block(&mut block);
183
184 let norm = 1.0 / (self.block_size as f32).sqrt();
186 for (i, idx) in (start..end).enumerate() {
187 result[idx] = block[i] * norm;
188 }
189 }
190
191 result
192 }
193
194 fn hadamard_transform_block(v: &mut [f32]) {
195 let n = v.len();
196 let mut h = 1;
197 while h < n {
198 for i in (0..n).step_by(h * 2) {
199 for j in i..(i + h) {
200 let x = v[j];
201 let y = v[j + h];
202 v[j] = x + y;
203 v[j + h] = x - y;
204 }
205 }
206 h *= 2;
207 }
208 }
209}
210
#[cfg(test)]
mod tests {
    use super::*;

    /// Euclidean (L2) norm of a slice.
    fn l2_norm(values: &[f32]) -> f32 {
        values.iter().map(|v| v * v).sum::<f32>().sqrt()
    }

    /// Linear ramp `(i - offset) * step` for `i` in `0..len`.
    fn ramp(len: usize, offset: f32, step: f32) -> Vec<f32> {
        (0..len).map(|i| (i as f32 - offset) * step).collect()
    }

    #[test]
    fn test_rotator_preserves_norm() {
        let rotator = Rotator::new(64);
        let input = ramp(64, 32.0, 0.1);

        let before = l2_norm(&input);
        let after = l2_norm(&rotator.rotate(&input));

        assert!(
            (before - after).abs() < 0.01,
            "Norms differ: {} vs {}",
            before,
            after
        );
    }

    #[test]
    fn test_rotation_roundtrip() {
        let rotator = Rotator::new(64);
        let original = ramp(64, 32.0, 0.1);
        let recovered = rotator.rotate_inverse(&rotator.rotate(&original));

        original.iter().zip(&recovered).for_each(|(a, b)| {
            assert!((a - b).abs() < 0.01, "Mismatch: {} vs {}", a, b);
        });
    }

    #[test]
    fn test_hadamard_basic() {
        // A basis vector should spread into entries of equal magnitude.
        let rotator = Rotator::new(4);
        let output = rotator.rotate(&[1.0, 0.0, 0.0, 0.0]);

        let reference = output[0].abs();
        assert!(output.iter().all(|&v| (v.abs() - reference).abs() < 0.01));
    }

    #[test]
    fn test_block_rotator() {
        let rotator = BlockRotator::new(768, 64);
        let input = ramp(768, 384.0, 0.01);

        let before = l2_norm(&input);
        let after = l2_norm(&rotator.rotate(&input));

        assert!((before - after).abs() < 0.1);
    }
}