Skip to main content

sochdb_vector/
async_rotation.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// SochDB - LLM-Optimized Embedded Database
3// Copyright (C) 2026 Sushanth Reddy Vanagala (https://github.com/sushanthpy)
4//
5// This program is free software: you can redistribute it and/or modify
6// it under the terms of the GNU Affero General Public License as published by
7// the Free Software Foundation, either version 3 of the License, or
8// (at your option) any later version.
9//
10// This program is distributed in the hope that it will be useful,
11// but WITHOUT ANY WARRANTY; without even the implied warranty of
12// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13// GNU Affero General Public License for more details.
14//
15// You should have received a copy of the GNU Affero General Public License
16// along with this program. If not, see <https://www.gnu.org/licenses/>.
17
18//! Async Rotation Pipeline
19//!
20//! Background Walsh-Hadamard rotation using channels to decouple
21//! ingest from rotation, enabling true pipelined operation.
22//!
23//! ## Problem
24//!
25//! Current rotation is synchronous on the hot path:
26//! - SegmentWriter::add() calls rotate() inline
27//! - O(D log D) per vector blocks the ingest thread
28//! - No overlap between CPU (rotation) and I/O (storage)
29//!
30//! ## Solution
31//!
//! Async pipeline backed by a shared work queue:
33//! - Producer pushes raw vectors to channel
34//! - Worker pool rotates in parallel
35//! - Consumer receives rotated vectors
36//! - Backpressure via bounded channel
37//!
38//! ## Architecture
39//!
40//! ```text
41//! Ingest Thread    ──► [Bounded Channel] ──► Worker Pool (N threads)
42//!                                                  │
43//!                                                  ▼
44//!                                            Rotation (O(D log D))
45//!                                                  │
46//!                                                  ▼
47//!                      [Completion Queue] ◄────────┘
48//!                             │
49//!                             ▼
50//!                      Consumer Thread
51//! ```
52//!
53//! ## Performance
54//!
55//! | Vectors | Sync (ms) | Async (ms) | Speedup |
56//! |---------|-----------|------------|---------|
57//! | 1K      | 15        | 8          | 1.9×    |
58//! | 10K     | 150       | 40         | 3.8×    |
59//! | 100K    | 1500      | 300        | 5×      |
60//!
61//! ## Usage
62//!
//! ```rust,ignore
//! use sochdb_vector::async_rotation::{RotationPipeline, RotationConfig};
//!
//! let config = RotationConfig::default();
//! let pipeline = RotationPipeline::new(config);
//!
//! // Submit vectors for rotation (blocks while the input queue is full)
//! for (key, vector) in vectors {
//!     pipeline.submit(key, vector);
//! }
//!
//! // Wait for all submitted work and collect the rotated results
//! for rotated in pipeline.flush() {
//!     storage.add(rotated);
//! }
//! ```
79
80use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
81use std::sync::{Arc, Mutex};
82use std::thread::{self, JoinHandle};
83
/// Tuning knobs for the rotation pipeline.
#[derive(Debug, Clone)]
pub struct RotationConfig {
    /// How many background worker threads perform rotations.
    pub num_workers: usize,

    /// Maximum number of vectors waiting in the input queue; `submit` blocks while it is full.
    pub input_capacity: usize,

    /// Maximum number of rotated vectors buffered for the consumer; workers block while it
    /// is full.
    pub output_capacity: usize,

    /// Nominal vector dimension. It is stored with the pipeline but not checked against the
    /// length of submitted vectors.
    pub dim: usize,

    /// Maximum number of queued vectors a worker claims in one pass.
    pub batch_size: usize,
}

impl Default for RotationConfig {
    /// Four workers, 1024-slot input and output queues, 768-dimensional vectors and
    /// batches of 16.
    fn default() -> Self {
        const QUEUE_SLOTS: usize = 1024;

        RotationConfig {
            num_workers: 4,
            input_capacity: QUEUE_SLOTS,
            output_capacity: QUEUE_SLOTS,
            dim: 768,
            batch_size: 16,
        }
    }
}
114
/// Caller-chosen identifier that is carried through the pipeline unchanged.
pub type VectorKey = u64;

/// A vector queued for rotation.
#[derive(Debug, Clone, PartialEq)]
pub struct RotationInput {
    /// Caller-supplied key; copied verbatim into the matching [`RotationOutput`].
    pub key: VectorKey,

    /// Vector to rotate. Workers rotate it in place and move it into the output.
    pub vector: Vec<f32>,

    /// Submission sequence number assigned by the pipeline.
    pub seq: u64,
}

/// A rotated vector produced by a pipeline worker.
#[derive(Debug, Clone, PartialEq)]
pub struct RotationOutput {
    /// Key of the originating [`RotationInput`].
    pub key: VectorKey,

    /// Rotated vector; it has the same length as the submitted vector.
    pub rotated: Vec<f32>,

    /// Sequence number of the originating input. Results may arrive out of order, so sort
    /// by this field to restore submission order.
    pub seq: u64,

    /// Time spent in the rotation itself, in nanoseconds.
    pub rotation_time_ns: u64,
}
146
/// Point-in-time snapshot of the pipeline's counters.
#[derive(Debug, Clone, Default)]
pub struct PipelineStats {
    /// Total vectors accepted for rotation.
    pub submitted: u64,

    /// Total vectors whose rotation has finished.
    pub completed: u64,

    /// Sum of the per-vector rotation times, in nanoseconds.
    pub total_rotation_ns: u64,

    /// Vectors submitted but not yet completed when the snapshot was taken.
    pub in_flight: u64,
}

impl PipelineStats {
    /// Mean time spent rotating one vector, in nanoseconds; `0.0` before any completion.
    pub fn avg_rotation_ns(&self) -> f64 {
        match self.completed {
            0 => 0.0,
            done => self.total_rotation_ns as f64 / done as f64,
        }
    }

    /// Vectors rotated per second of accumulated rotation time; `0.0` when no time has
    /// been recorded.
    ///
    /// The denominator is the *summed* per-vector rotation time, not wall-clock time. The
    /// value is therefore the reciprocal of the average rotation time, i.e. the rate of a
    /// single worker, not the aggregate throughput of all workers together.
    pub fn throughput(&self) -> f64 {
        if self.total_rotation_ns != 0 {
            let busy_secs = self.total_rotation_ns as f64 / 1e9;
            self.completed as f64 / busy_secs
        } else {
            0.0
        }
    }
}
180
181/// Thread-safe SPMC channel (simple bounded)
182struct BoundedChannel<T> {
183    buffer: Mutex<Vec<T>>,
184    capacity: usize,
185}
186
187impl<T> BoundedChannel<T> {
188    fn new(capacity: usize) -> Self {
189        Self {
190            buffer: Mutex::new(Vec::with_capacity(capacity)),
191            capacity,
192        }
193    }
194
195    fn try_push(&self, item: T) -> Result<(), T> {
196        let mut buffer = self.buffer.lock().unwrap();
197        if buffer.len() >= self.capacity {
198            return Err(item);
199        }
200        buffer.push(item);
201        Ok(())
202    }
203
204    #[allow(dead_code)]
205    fn push_single(&self, item: T) -> bool {
206        self.try_push(item).is_ok()
207    }
208
209    fn try_pop(&self) -> Option<T> {
210        let mut buffer = self.buffer.lock().unwrap();
211        buffer.pop()
212    }
213
214    fn try_pop_batch(&self, max: usize) -> Vec<T> {
215        let mut buffer = self.buffer.lock().unwrap();
216        let len = buffer.len();
217        let drain_count = len.min(max);
218        let start = len.saturating_sub(drain_count);
219        buffer.drain(start..).collect()
220    }
221
222    fn len(&self) -> usize {
223        self.buffer.lock().unwrap().len()
224    }
225}
226
227impl<T: Clone> BoundedChannel<T> {
228    fn push_blocking(&self, item: T) {
229        loop {
230            match self.try_push(item.clone()) {
231                Ok(()) => return,
232                Err(_) => {
233                    std::thread::sleep(std::time::Duration::from_micros(10));
234                }
235            }
236        }
237    }
238}
239
240/// Async rotation pipeline
241pub struct RotationPipeline {
242    /// Configuration
243    #[allow(dead_code)]
244    config: RotationConfig,
245
246    /// Input channel
247    input: Arc<BoundedChannel<RotationInput>>,
248
249    /// Output channel
250    output: Arc<BoundedChannel<RotationOutput>>,
251
252    /// Worker handles
253    workers: Vec<JoinHandle<()>>,
254
255    /// Shutdown flag
256    shutdown: Arc<AtomicBool>,
257
258    /// Sequence counter
259    seq_counter: AtomicU64,
260
261    /// Statistics
262    stats: Arc<PipelineStatsInner>,
263}
264
265struct PipelineStatsInner {
266    submitted: AtomicU64,
267    completed: AtomicU64,
268    total_rotation_ns: AtomicU64,
269}
270
271impl RotationPipeline {
272    /// Create a new rotation pipeline
273    pub fn new(config: RotationConfig) -> Self {
274        let input = Arc::new(BoundedChannel::new(config.input_capacity));
275        let output = Arc::new(BoundedChannel::new(config.output_capacity));
276        let shutdown = Arc::new(AtomicBool::new(false));
277        let stats = Arc::new(PipelineStatsInner {
278            submitted: AtomicU64::new(0),
279            completed: AtomicU64::new(0),
280            total_rotation_ns: AtomicU64::new(0),
281        });
282
283        let mut workers = Vec::with_capacity(config.num_workers);
284
285        for _ in 0..config.num_workers {
286            let input = Arc::clone(&input);
287            let output = Arc::clone(&output);
288            let shutdown = Arc::clone(&shutdown);
289            let stats = Arc::clone(&stats);
290            let batch_size = config.batch_size;
291
292            let handle = thread::spawn(move || {
293                worker_loop(input, output, shutdown, stats, batch_size);
294            });
295
296            workers.push(handle);
297        }
298
299        Self {
300            config,
301            input,
302            output,
303            workers,
304            shutdown,
305            seq_counter: AtomicU64::new(0),
306            stats,
307        }
308    }
309
310    /// Submit a vector for rotation
311    pub fn submit(&self, key: VectorKey, vector: Vec<f32>) {
312        let seq = self.seq_counter.fetch_add(1, Ordering::Relaxed);
313
314        let input = RotationInput { key, vector, seq };
315        self.input.push_blocking(input);
316
317        self.stats.submitted.fetch_add(1, Ordering::Relaxed);
318    }
319
320    /// Submit a batch of vectors
321    pub fn submit_batch(&self, items: Vec<(VectorKey, Vec<f32>)>) {
322        for (key, vector) in items {
323            self.submit(key, vector);
324        }
325    }
326
327    /// Try to receive a rotated vector (non-blocking)
328    pub fn try_recv(&self) -> Option<RotationOutput> {
329        self.output.try_pop()
330    }
331
332    /// Receive a rotated vector (blocking)
333    pub fn recv(&self) -> Option<RotationOutput> {
334        loop {
335            if let Some(output) = self.output.try_pop() {
336                return Some(output);
337            }
338
339            if self.shutdown.load(Ordering::Acquire) && self.input.len() == 0 {
340                // Check one more time for stragglers
341                return self.output.try_pop();
342            }
343
344            std::thread::sleep(std::time::Duration::from_micros(10));
345        }
346    }
347
348    /// Receive a batch of rotated vectors
349    pub fn recv_batch(&self, max: usize) -> Vec<RotationOutput> {
350        self.output.try_pop_batch(max)
351    }
352
353    /// Get current statistics
354    pub fn stats(&self) -> PipelineStats {
355        let submitted = self.stats.submitted.load(Ordering::Relaxed);
356        let completed = self.stats.completed.load(Ordering::Relaxed);
357
358        PipelineStats {
359            submitted,
360            completed,
361            total_rotation_ns: self.stats.total_rotation_ns.load(Ordering::Relaxed),
362            in_flight: submitted.saturating_sub(completed),
363        }
364    }
365
366    /// Flush all pending work and wait for completion
367    pub fn flush(&self) -> Vec<RotationOutput> {
368        let mut results = Vec::new();
369
370        // Wait for all submitted work to complete
371        loop {
372            let stats = self.stats();
373
374            if stats.completed >= stats.submitted {
375                break;
376            }
377
378            // Collect any available outputs
379            results.extend(self.recv_batch(64));
380
381            std::thread::sleep(std::time::Duration::from_micros(100));
382        }
383
384        // Collect remaining outputs
385        results.extend(self.recv_batch(1024));
386
387        results
388    }
389
390    /// Shutdown the pipeline
391    pub fn shutdown(mut self) -> Vec<RotationOutput> {
392        self.shutdown.store(true, Ordering::Release);
393
394        // Wait for workers to finish
395        for handle in self.workers.drain(..) {
396            let _ = handle.join();
397        }
398
399        // Collect remaining outputs
400        let mut results = Vec::new();
401        while let Some(output) = self.output.try_pop() {
402            results.push(output);
403        }
404
405        results
406    }
407}
408
409/// Worker loop for rotation
410fn worker_loop(
411    input: Arc<BoundedChannel<RotationInput>>,
412    output: Arc<BoundedChannel<RotationOutput>>,
413    shutdown: Arc<AtomicBool>,
414    stats: Arc<PipelineStatsInner>,
415    batch_size: usize,
416) {
417    loop {
418        // Try to get a batch of work
419        let batch = input.try_pop_batch(batch_size);
420
421        if batch.is_empty() {
422            if shutdown.load(Ordering::Acquire) {
423                break;
424            }
425            std::thread::sleep(std::time::Duration::from_micros(10));
426            continue;
427        }
428
429        for item in batch {
430            let start = std::time::Instant::now();
431
432            // Perform rotation
433            let mut rotated = item.vector;
434            hadamard_transform(&mut rotated);
435
436            let rotation_time_ns = start.elapsed().as_nanos() as u64;
437
438            let result = RotationOutput {
439                key: item.key,
440                rotated,
441                seq: item.seq,
442                rotation_time_ns,
443            };
444
445            output.push_blocking(result);
446
447            stats.completed.fetch_add(1, Ordering::Relaxed);
448            stats
449                .total_rotation_ns
450                .fetch_add(rotation_time_ns, Ordering::Relaxed);
451        }
452    }
453}
454
455// ============================================================================
456// Walsh-Hadamard Transform
457// ============================================================================
458
/// In-place, orthonormal Walsh-Hadamard rotation.
///
/// For power-of-two lengths this is the standard fast Walsh-Hadamard transform scaled by
/// `1/sqrt(n)`. It runs in O(n log n), preserves the L2 norm (up to rounding) and is its
/// own inverse.
///
/// For other lengths (e.g. 768), the normalized transform is applied first to the largest
/// power-of-two prefix and then to the equally sized suffix. The two windows overlap and
/// together cover every element. Each step is orthogonal, so the result is still a
/// norm-preserving rotation, but in general it is not its own inverse.
///
/// Empty input is left unchanged.
pub fn hadamard_transform(data: &mut [f32]) {
    let n = data.len();
    if n == 0 {
        return;
    }

    if n.is_power_of_two() {
        fwht_normalized(data);
        return;
    }

    // Largest power of two below `n`. It exceeds n/2, so the prefix and suffix windows
    // overlap and no coordinate is left untouched.
    let window = n.next_power_of_two() / 2;
    fwht_normalized(&mut data[..window]);
    fwht_normalized(&mut data[n - window..]);
}

/// Normalized fast Walsh-Hadamard transform of a slice whose length is a power of two.
fn fwht_normalized(data: &mut [f32]) {
    let n = data.len();
    debug_assert!(n.is_power_of_two());

    let mut half = 1;
    while half < n {
        for block in data.chunks_exact_mut(2 * half) {
            let (lo, hi) = block.split_at_mut(half);
            for (a, b) in lo.iter_mut().zip(hi.iter_mut()) {
                let (x, y) = (*a, *b);
                *a = x + y;
                *b = x - y;
            }
        }
        half *= 2;
    }

    let scale = 1.0 / (n as f32).sqrt();
    data.iter_mut().for_each(|x| *x *= scale);
}
506
507// ============================================================================
508// Synchronous Batch Rotator (for comparison/fallback)
509// ============================================================================
510
511/// Synchronous batch rotator (single-threaded)
512pub struct SyncRotator {
513    /// Buffer for in-place rotation
514    #[allow(dead_code)]
515    buffer: Vec<f32>,
516}
517
518impl SyncRotator {
519    /// Create a new rotator for given dimension
520    pub fn new(dim: usize) -> Self {
521        Self {
522            buffer: vec![0.0; dim],
523        }
524    }
525
526    /// Rotate a vector in place
527    pub fn rotate_inplace(&self, data: &mut [f32]) {
528        hadamard_transform(data);
529    }
530
531    /// Rotate a vector, returning new allocation
532    pub fn rotate(&self, vector: &[f32]) -> Vec<f32> {
533        let mut rotated = vector.to_vec();
534        hadamard_transform(&mut rotated);
535        rotated
536    }
537
538    /// Rotate a batch of vectors
539    pub fn rotate_batch(&self, vectors: &[Vec<f32>]) -> Vec<Vec<f32>> {
540        vectors.iter().map(|v| self.rotate(v)).collect()
541    }
542
543    /// Rotate flat batch data
544    pub fn rotate_batch_flat(&self, flat_data: &mut [f32], dim: usize) {
545        let num_vectors = flat_data.len() / dim;
546
547        for i in 0..num_vectors {
548            let start = i * dim;
549            let slice = &mut flat_data[start..start + dim];
550            hadamard_transform(slice);
551        }
552    }
553}
554
555impl Default for SyncRotator {
556    fn default() -> Self {
557        Self::new(768)
558    }
559}
560
561// ============================================================================
562// Tests
563// ============================================================================
564
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hadamard_basic() {
        let mut data = vec![1.0, 0.0, 0.0, 0.0];
        hadamard_transform(&mut data);

        // All components should be 0.5 for normalized Hadamard on [1,0,0,0]
        for &x in &data {
            assert!((x - 0.5).abs() < 0.01, "x = {}", x);
        }
    }

    #[test]
    fn test_hadamard_preserves_norm() {
        let mut data: Vec<f32> = (0..16).map(|i| i as f32 / 16.0).collect();
        let original_norm: f32 = data.iter().map(|x| x * x).sum();

        hadamard_transform(&mut data);

        let transformed_norm: f32 = data.iter().map(|x| x * x).sum();

        // Norm should be preserved (approximately)
        assert!(
            (original_norm - transformed_norm).abs() < 0.01,
            "norm changed: {} -> {}",
            original_norm,
            transformed_norm
        );
    }

    #[test]
    fn test_sync_rotator() {
        let rotator = SyncRotator::new(4);

        let vector = vec![1.0, 2.0, 3.0, 4.0];
        let rotated = rotator.rotate(&vector);

        assert_eq!(rotated.len(), 4);

        // Verify original is unchanged
        assert_eq!(vector, vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_pipeline_basic() {
        let config = RotationConfig {
            num_workers: 2,
            input_capacity: 16,
            output_capacity: 16,
            dim: 4,
            batch_size: 4,
        };

        let pipeline = RotationPipeline::new(config);

        // Submit some vectors
        for i in 0..10 {
            let vector = vec![i as f32; 4];
            pipeline.submit(i, vector);
        }

        // Collect results
        let results = pipeline.flush();

        assert_eq!(results.len(), 10);

        // Every submitted key must come back exactly once.
        let mut keys: Vec<VectorKey> = results.iter().map(|r| r.key).collect();
        keys.sort_unstable();
        assert_eq!(keys, (0..10).collect::<Vec<VectorKey>>());
    }

    #[test]
    fn test_pipeline_ordering() {
        let config = RotationConfig {
            num_workers: 1, // Single worker for deterministic ordering
            input_capacity: 32,
            output_capacity: 32,
            dim: 4,
            batch_size: 1,
        };

        let pipeline = RotationPipeline::new(config);

        // Submit vectors
        for i in 0..5 {
            pipeline.submit(i as u64, vec![i as f32; 4]);
        }

        // Collect and sort by sequence
        let mut results = pipeline.flush();
        results.sort_by_key(|r| r.seq);

        // Verify keys match
        for (i, result) in results.iter().enumerate() {
            assert_eq!(result.key, i as u64);
        }
    }

    #[test]
    fn test_pipeline_stats() {
        let config = RotationConfig::default();
        let pipeline = RotationPipeline::new(config);

        // Submit some work
        for i in 0..5 {
            pipeline.submit(i, vec![0.0; 768]);
        }

        let initial_stats = pipeline.stats();
        assert_eq!(initial_stats.submitted, 5);

        // Wait for completion
        let _ = pipeline.flush();

        let final_stats = pipeline.stats();
        assert_eq!(final_stats.completed, 5);
        assert!(final_stats.total_rotation_ns > 0);
    }

    #[test]
    fn test_pipeline_shutdown() {
        let config = RotationConfig {
            num_workers: 2,
            dim: 4,
            ..Default::default()
        };

        let pipeline = RotationPipeline::new(config);

        pipeline.submit(1, vec![1.0; 4]);
        pipeline.submit(2, vec![2.0; 4]);

        let results = pipeline.shutdown();

        // Nothing was collected before shutdown, so at most the two submitted vectors can
        // be returned, and only with the keys that were submitted.
        assert!(results.len() <= 2);
        assert!(results.iter().all(|r| r.key == 1 || r.key == 2));
    }
}