use std::collections::VecDeque;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::thread::{self, JoinHandle};
83
/// Tuning knobs for [`RotationPipeline`].
#[derive(Debug, Clone)]
pub struct RotationConfig {
    /// Number of worker threads spawned by [`RotationPipeline::new`].
    pub num_workers: usize,

    /// Maximum number of pending inputs queued before `submit` blocks.
    pub input_capacity: usize,

    /// Maximum number of finished outputs buffered before workers block.
    pub output_capacity: usize,

    /// Expected vector dimensionality.
    ///
    /// NOTE(review): the pipeline stores this but never reads it; vectors of
    /// any length are accepted.
    pub dim: usize,

    /// Maximum number of inputs a worker takes from the queue in one go.
    pub batch_size: usize,
}

impl Default for RotationConfig {
    /// Four workers, 1024-slot input and output queues, 768-dimensional
    /// vectors and batches of 16.
    fn default() -> Self {
        let queue_capacity = 1024;
        Self {
            dim: 768,
            batch_size: 16,
            num_workers: 4,
            input_capacity: queue_capacity,
            output_capacity: queue_capacity,
        }
    }
}
114
/// Caller-chosen identifier attached to every submitted vector.
pub type VectorKey = u64;

/// A vector queued for rotation by [`RotationPipeline`].
#[derive(Debug, Clone)]
pub struct RotationInput {
    /// Caller-supplied identifier, copied unchanged into the matching output.
    pub key: VectorKey,

    /// Vector to rotate; a worker takes ownership and rotates it in place.
    pub vector: Vec<f32>,

    /// Submission sequence number assigned by [`RotationPipeline::submit`].
    pub seq: u64,
}

/// The rotated form of a [`RotationInput`], produced by a worker thread.
#[derive(Debug, Clone)]
pub struct RotationOutput {
    /// Key of the input this result belongs to.
    pub key: VectorKey,

    /// The rotated vector.
    pub rotated: Vec<f32>,

    /// Sequence number of the originating input; results may arrive out of
    /// submission order, so sort by this field when order matters.
    pub seq: u64,

    /// Wall-clock time spent in the rotation itself, in nanoseconds.
    pub rotation_time_ns: u64,
}
146
/// Point-in-time snapshot of a [`RotationPipeline`]'s counters.
#[derive(Debug, Clone, Default)]
pub struct PipelineStats {
    /// Vectors accepted by `submit` so far.
    pub submitted: u64,

    /// Vectors whose rotated output has been produced.
    pub completed: u64,

    /// Sum of per-vector rotation times across all workers, in nanoseconds.
    pub total_rotation_ns: u64,

    /// `submitted - completed`, saturating at zero.
    pub in_flight: u64,
}

impl PipelineStats {
    /// Mean time spent rotating one vector, in nanoseconds.
    ///
    /// Returns `0.0` when nothing has completed yet.
    pub fn avg_rotation_ns(&self) -> f64 {
        match self.completed {
            0 => 0.0,
            done => self.total_rotation_ns as f64 / done as f64,
        }
    }

    /// Vectors rotated per second of accumulated rotation time.
    ///
    /// Because `total_rotation_ns` sums time over every worker, this is a
    /// per-worker compute rate rather than wall-clock pipeline throughput.
    /// Returns `0.0` when no rotation time has been recorded.
    pub fn throughput(&self) -> f64 {
        if self.total_rotation_ns != 0 {
            let busy_seconds = self.total_rotation_ns as f64 / 1e9;
            self.completed as f64 / busy_seconds
        } else {
            0.0
        }
    }
}
180
/// Fixed-capacity, mutex-protected FIFO queue shared between the pipeline
/// handle and its worker threads.
///
/// Items leave in the order they arrived, so the oldest pending input (or
/// oldest finished output) is always served first and cannot be starved by
/// newer items.
struct BoundedChannel<T> {
    buffer: Mutex<VecDeque<T>>,
    /// Maximum number of queued items; always at least 1.
    capacity: usize,
}

impl<T> BoundedChannel<T> {
    /// Creates an empty channel that holds at most `capacity` items.
    ///
    /// A capacity of 0 is raised to 1: a zero-capacity channel could never
    /// accept an item, and `push_blocking` would spin forever.
    fn new(capacity: usize) -> Self {
        let capacity = capacity.max(1);
        Self {
            buffer: Mutex::new(VecDeque::with_capacity(capacity)),
            capacity,
        }
    }

    /// Locks the underlying queue.
    ///
    /// # Panics
    /// Panics if another thread panicked while holding the lock.
    fn lock(&self) -> std::sync::MutexGuard<'_, VecDeque<T>> {
        self.buffer.lock().expect("BoundedChannel mutex poisoned")
    }

    /// Appends `item` at the back, or hands it back in `Err` if the channel is full.
    fn try_push(&self, item: T) -> Result<(), T> {
        let mut buffer = self.lock();
        if buffer.len() >= self.capacity {
            return Err(item);
        }
        buffer.push_back(item);
        Ok(())
    }

    /// Non-blocking push that reports success as a `bool`.
    #[allow(dead_code)] // convenience wrapper with no caller in the pipeline
    fn push_single(&self, item: T) -> bool {
        self.try_push(item).is_ok()
    }

    /// Appends `item`, sleeping briefly and retrying while the channel is full.
    ///
    /// The item handed back by a failed `try_push` is reused, so retries never
    /// clone it.
    fn push_blocking(&self, mut item: T) {
        loop {
            match self.try_push(item) {
                Ok(()) => return,
                Err(rejected) => {
                    item = rejected;
                    std::thread::sleep(std::time::Duration::from_micros(10));
                }
            }
        }
    }

    /// Removes and returns the oldest item, if any.
    fn try_pop(&self) -> Option<T> {
        self.lock().pop_front()
    }

    /// Removes up to `max` of the oldest items and returns them oldest first.
    fn try_pop_batch(&self, max: usize) -> Vec<T> {
        let mut buffer = self.lock();
        let count = buffer.len().min(max);
        buffer.drain(..count).collect()
    }

    /// Current number of queued items.
    #[allow(dead_code)] // diagnostic helper; may have no callers
    fn len(&self) -> usize {
        self.lock().len()
    }
}
239
/// Multi-threaded pipeline that rotates submitted vectors on background
/// worker threads and buffers the results for retrieval.
pub struct RotationPipeline {
    /// Configuration the pipeline was built with.
    #[allow(dead_code)] // not read after construction
    config: RotationConfig,

    /// Queue of pending inputs, shared with every worker.
    input: Arc<BoundedChannel<RotationInput>>,

    /// Queue of finished outputs, shared with every worker.
    output: Arc<BoundedChannel<RotationOutput>>,

    /// Join handles of the spawned worker threads.
    workers: Vec<JoinHandle<()>>,

    /// Flag telling workers to exit once they find the input queue empty.
    shutdown: Arc<AtomicBool>,

    /// Source of the per-submission sequence numbers.
    seq_counter: AtomicU64,

    /// Counters shared with the workers; snapshotted by `stats()`.
    stats: Arc<PipelineStatsInner>,
}
264
/// Atomic counters shared between the pipeline handle and its workers;
/// `RotationPipeline::stats` copies them into a `PipelineStats` snapshot.
struct PipelineStatsInner {
    // Incremented once per `submit` call.
    submitted: AtomicU64,
    // Incremented by a worker after it has pushed a result to the output queue.
    completed: AtomicU64,
    // Sum of measured per-vector rotation times, in nanoseconds.
    total_rotation_ns: AtomicU64,
}
270
271impl RotationPipeline {
272 pub fn new(config: RotationConfig) -> Self {
274 let input = Arc::new(BoundedChannel::new(config.input_capacity));
275 let output = Arc::new(BoundedChannel::new(config.output_capacity));
276 let shutdown = Arc::new(AtomicBool::new(false));
277 let stats = Arc::new(PipelineStatsInner {
278 submitted: AtomicU64::new(0),
279 completed: AtomicU64::new(0),
280 total_rotation_ns: AtomicU64::new(0),
281 });
282
283 let mut workers = Vec::with_capacity(config.num_workers);
284
285 for _ in 0..config.num_workers {
286 let input = Arc::clone(&input);
287 let output = Arc::clone(&output);
288 let shutdown = Arc::clone(&shutdown);
289 let stats = Arc::clone(&stats);
290 let batch_size = config.batch_size;
291
292 let handle = thread::spawn(move || {
293 worker_loop(input, output, shutdown, stats, batch_size);
294 });
295
296 workers.push(handle);
297 }
298
299 Self {
300 config,
301 input,
302 output,
303 workers,
304 shutdown,
305 seq_counter: AtomicU64::new(0),
306 stats,
307 }
308 }
309
310 pub fn submit(&self, key: VectorKey, vector: Vec<f32>) {
312 let seq = self.seq_counter.fetch_add(1, Ordering::Relaxed);
313
314 let input = RotationInput { key, vector, seq };
315 self.input.push_blocking(input);
316
317 self.stats.submitted.fetch_add(1, Ordering::Relaxed);
318 }
319
320 pub fn submit_batch(&self, items: Vec<(VectorKey, Vec<f32>)>) {
322 for (key, vector) in items {
323 self.submit(key, vector);
324 }
325 }
326
327 pub fn try_recv(&self) -> Option<RotationOutput> {
329 self.output.try_pop()
330 }
331
332 pub fn recv(&self) -> Option<RotationOutput> {
334 loop {
335 if let Some(output) = self.output.try_pop() {
336 return Some(output);
337 }
338
339 if self.shutdown.load(Ordering::Acquire) && self.input.len() == 0 {
340 return self.output.try_pop();
342 }
343
344 std::thread::sleep(std::time::Duration::from_micros(10));
345 }
346 }
347
348 pub fn recv_batch(&self, max: usize) -> Vec<RotationOutput> {
350 self.output.try_pop_batch(max)
351 }
352
353 pub fn stats(&self) -> PipelineStats {
355 let submitted = self.stats.submitted.load(Ordering::Relaxed);
356 let completed = self.stats.completed.load(Ordering::Relaxed);
357
358 PipelineStats {
359 submitted,
360 completed,
361 total_rotation_ns: self.stats.total_rotation_ns.load(Ordering::Relaxed),
362 in_flight: submitted.saturating_sub(completed),
363 }
364 }
365
366 pub fn flush(&self) -> Vec<RotationOutput> {
368 let mut results = Vec::new();
369
370 loop {
372 let stats = self.stats();
373
374 if stats.completed >= stats.submitted {
375 break;
376 }
377
378 results.extend(self.recv_batch(64));
380
381 std::thread::sleep(std::time::Duration::from_micros(100));
382 }
383
384 results.extend(self.recv_batch(1024));
386
387 results
388 }
389
390 pub fn shutdown(mut self) -> Vec<RotationOutput> {
392 self.shutdown.store(true, Ordering::Release);
393
394 for handle in self.workers.drain(..) {
396 let _ = handle.join();
397 }
398
399 let mut results = Vec::new();
401 while let Some(output) = self.output.try_pop() {
402 results.push(output);
403 }
404
405 results
406 }
407}
408
409fn worker_loop(
411 input: Arc<BoundedChannel<RotationInput>>,
412 output: Arc<BoundedChannel<RotationOutput>>,
413 shutdown: Arc<AtomicBool>,
414 stats: Arc<PipelineStatsInner>,
415 batch_size: usize,
416) {
417 loop {
418 let batch = input.try_pop_batch(batch_size);
420
421 if batch.is_empty() {
422 if shutdown.load(Ordering::Acquire) {
423 break;
424 }
425 std::thread::sleep(std::time::Duration::from_micros(10));
426 continue;
427 }
428
429 for item in batch {
430 let start = std::time::Instant::now();
431
432 let mut rotated = item.vector;
434 hadamard_transform(&mut rotated);
435
436 let rotation_time_ns = start.elapsed().as_nanos() as u64;
437
438 let result = RotationOutput {
439 key: item.key,
440 rotated,
441 seq: item.seq,
442 rotation_time_ns,
443 };
444
445 output.push_blocking(result);
446
447 stats.completed.fetch_add(1, Ordering::Relaxed);
448 stats
449 .total_rotation_ns
450 .fetch_add(rotation_time_ns, Ordering::Relaxed);
451 }
452 }
453}
454
/// Applies an orthonormal fast Walsh–Hadamard transform to `data` in place.
///
/// For power-of-two lengths, the butterfly passes are followed by a
/// `1/sqrt(n)` scale, so the Euclidean norm is preserved. An empty slice is
/// left untouched.
///
/// NOTE(review): any other length is *not* transformed. It is only scaled to
/// unit L2 norm by `normalize_vector`, which neither mixes coordinates nor
/// preserves the norm. For example, 768-dimensional vectors (the default
/// `RotationConfig::dim`) take this path. Confirm this is intended.
pub fn hadamard_transform(data: &mut [f32]) {
    let len = data.len();
    if len == 0 {
        return;
    }
    if !len.is_power_of_two() {
        normalize_vector(data);
        return;
    }

    // Each pass pairs element k of the low half of every 2*half block with
    // element k of its high half.
    let mut half = 1;
    while half < len {
        for block in data.chunks_exact_mut(half * 2) {
            let (low, high) = block.split_at_mut(half);
            for (a, b) in low.iter_mut().zip(high.iter_mut()) {
                let (x, y) = (*a, *b);
                *a = x + y;
                *b = x - y;
            }
        }
        half <<= 1;
    }

    let scale = 1.0 / (len as f32).sqrt();
    data.iter_mut().for_each(|v| *v *= scale);
}

/// Scales `data` to unit L2 norm in place.
///
/// A vector is left unchanged when its norm is not greater than `1e-10`.
/// That covers empty and all-zero vectors, and vectors whose norm is NaN.
fn normalize_vector(data: &mut [f32]) {
    let norm = data.iter().map(|&v| v * v).sum::<f32>().sqrt();
    if norm > 1e-10 {
        data.iter_mut().for_each(|v| *v /= norm);
    }
}
506
507pub struct SyncRotator {
513 #[allow(dead_code)]
515 buffer: Vec<f32>,
516}
517
518impl SyncRotator {
519 pub fn new(dim: usize) -> Self {
521 Self {
522 buffer: vec![0.0; dim],
523 }
524 }
525
526 pub fn rotate_inplace(&self, data: &mut [f32]) {
528 hadamard_transform(data);
529 }
530
531 pub fn rotate(&self, vector: &[f32]) -> Vec<f32> {
533 let mut rotated = vector.to_vec();
534 hadamard_transform(&mut rotated);
535 rotated
536 }
537
538 pub fn rotate_batch(&self, vectors: &[Vec<f32>]) -> Vec<Vec<f32>> {
540 vectors.iter().map(|v| self.rotate(v)).collect()
541 }
542
543 pub fn rotate_batch_flat(&self, flat_data: &mut [f32], dim: usize) {
545 let num_vectors = flat_data.len() / dim;
546
547 for i in 0..num_vectors {
548 let start = i * dim;
549 let slice = &mut flat_data[start..start + dim];
550 hadamard_transform(slice);
551 }
552 }
553}
554
555impl Default for SyncRotator {
556 fn default() -> Self {
557 Self::new(768)
558 }
559}
560
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hadamard_basic() {
        let mut data = vec![1.0, 0.0, 0.0, 0.0];
        hadamard_transform(&mut data);

        // A unit impulse spreads evenly: each entry becomes 1/sqrt(4).
        for &x in &data {
            assert!((x - 0.5).abs() < 0.01, "x = {}", x);
        }
    }

    #[test]
    fn test_hadamard_preserves_norm() {
        let mut data: Vec<f32> = (0..16).map(|i| i as f32 / 16.0).collect();
        let original_norm: f32 = data.iter().map(|x| x * x).sum();

        hadamard_transform(&mut data);

        let transformed_norm: f32 = data.iter().map(|x| x * x).sum();
        assert!(
            (original_norm - transformed_norm).abs() < 0.01,
            "norm changed: {} -> {}",
            original_norm,
            transformed_norm
        );
    }

    #[test]
    fn test_sync_rotator() {
        let rotator = SyncRotator::new(4);

        let vector = vec![1.0, 2.0, 3.0, 4.0];
        let rotated = rotator.rotate(&vector);

        // Butterflies give [10, -2, -4, 0], then scaling by 1/2.
        assert_eq!(rotated, vec![5.0, -1.0, -2.0, 0.0]);
        // `rotate` must not mutate its input.
        assert_eq!(vector, vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_rotate_batch_flat_leaves_partial_tail() {
        let rotator = SyncRotator::default();
        let mut flat = vec![1.0, 0.0, 0.0, 0.0, 9.0];

        rotator.rotate_batch_flat(&mut flat, 4);

        assert_eq!(flat, vec![0.5, 0.5, 0.5, 0.5, 9.0]);
    }

    #[test]
    fn test_pipeline_basic() {
        let config = RotationConfig {
            num_workers: 2,
            input_capacity: 16,
            output_capacity: 16,
            dim: 4,
            batch_size: 4,
        };
        let pipeline = RotationPipeline::new(config);

        for i in 0..10 {
            pipeline.submit(i, vec![i as f32; 4]);
        }
        let results = pipeline.flush();

        assert_eq!(results.len(), 10);
    }

    #[test]
    fn test_pipeline_ordering() {
        let config = RotationConfig {
            num_workers: 1,
            input_capacity: 32,
            output_capacity: 32,
            dim: 4,
            batch_size: 1,
        };
        let pipeline = RotationPipeline::new(config);

        for i in 0..5 {
            pipeline.submit(i as u64, vec![i as f32; 4]);
        }
        let mut results = pipeline.flush();
        results.sort_by_key(|r| r.seq);

        assert_eq!(results.len(), 5);
        for (i, result) in results.iter().enumerate() {
            assert_eq!(result.key, i as u64);
        }
    }

    #[test]
    fn test_pipeline_stats() {
        let pipeline = RotationPipeline::new(RotationConfig::default());

        for i in 0..5 {
            pipeline.submit(i, vec![0.0; 768]);
        }
        assert_eq!(pipeline.stats().submitted, 5);

        let _ = pipeline.flush();

        let final_stats = pipeline.stats();
        assert_eq!(final_stats.completed, 5);
        assert_eq!(final_stats.in_flight, 0);
        assert!(final_stats.total_rotation_ns > 0);
    }

    #[test]
    fn test_pipeline_shutdown() {
        let config = RotationConfig {
            num_workers: 2,
            dim: 4,
            ..Default::default()
        };
        let pipeline = RotationPipeline::new(config);

        pipeline.submit(1, vec![1.0; 4]);
        pipeline.submit(2, vec![2.0; 4]);

        // Workers finish all queued input before exiting, so nothing is lost.
        let results = pipeline.shutdown();
        assert_eq!(results.len(), 2);
    }
}
700}