Skip to main content

sochdb_vector/
dispatch.rs

1//! Kernel dispatch for pure Rust SIMD implementations.
2//!
3//! Provides wrappers around optimized Rust SIMD kernels from the `simd` module.
4//! Uses runtime CPU detection to select optimal code path.
5//!
6//! # Migration from C++ FFI
7//!
8//! This module previously used C++ SIMD kernels via FFI. It has been migrated
9//! to pure Rust implementations in the `simd` module, providing:
10//! - Unified toolchain (no C++ compiler needed)
11//! - Cross-function inlining
12//! - Better error messages and debugging
13//! - `cargo miri` support for undefined behavior detection
14
15use std::sync::OnceLock;
16
17// Re-export from the simd module for backwards compatibility
18pub use crate::simd::dispatch::{CpuFeatures, SimdLevel};
19
20/// Global CPU features, detected once at first use
21static CPU_FEATURES: OnceLock<CpuFeatures> = OnceLock::new();
22
23/// Get detected CPU features (cached)
24pub fn cpu_features() -> &'static CpuFeatures {
25    CPU_FEATURES.get_or_init(CpuFeatures::detect)
26}
27
28/// Get best available SIMD level
29pub fn simd_level() -> SimdLevel {
30    cpu_features().best_level()
31}
32
33// ============================================================================
34// BPS Scan Dispatcher
35// ============================================================================
36
37/// BPS scan dispatcher - uses pure Rust SIMD implementations.
38pub struct BpsScanDispatcher;
39
40impl BpsScanDispatcher {
41    /// Scan BPS data and compute L1 distances.
42    ///
43    /// Data layout (SoA): bps_data[block * n_vec + vec]
44    ///
45    /// Uses pure Rust SIMD implementations for optimal performance.
46    pub fn scan(
47        bps: &[u8],
48        n_vec: usize,
49        n_blocks: usize,
50        _proj: usize, // Legacy parameter, kept for API compat
51        query: &[u8],
52        out: &mut [u16],
53    ) {
54        crate::simd::bps_scan::bps_scan(bps, n_vec, n_blocks, query, out);
55    }
56
57    /// New interface - returns u32 distances.
58    pub fn scan_u32(bps: &[u8], n_vec: usize, n_blocks: usize, query: &[u8], out: &mut [u32]) {
59        crate::simd::bps_scan::bps_scan_u32(bps, n_vec, n_blocks, query, out);
60    }
61
62    /// Rust fallback implementation (for testing).
63    #[allow(dead_code)]
64    pub(crate) fn scan_fallback(
65        bps: &[u8],
66        n_vec: usize,
67        n_blocks: usize,
68        proj: usize,
69        query: &[u8],
70        out: &mut [u16],
71    ) {
72        let slots = n_blocks * proj;
73
74        // Zero output
75        for d in out.iter_mut().take(n_vec) {
76            *d = 0;
77        }
78
79        for slot in 0..slots {
80            let q = query[slot] as i16;
81            let base = slot * n_vec;
82
83            for vec_id in 0..n_vec {
84                let v = bps[base + vec_id] as i16;
85                let diff = (q - v).abs() as u16;
86                out[vec_id] = out[vec_id].saturating_add(diff);
87            }
88        }
89    }
90
91    /// Rust fallback implementation for u32 output (for testing).
92    #[allow(dead_code)]
93    pub(crate) fn scan_fallback_u32(
94        bps: &[u8],
95        n_vec: usize,
96        n_blocks: usize,
97        query: &[u8],
98        out: &mut [u32],
99    ) {
100        // Zero output
101        for d in out.iter_mut().take(n_vec) {
102            *d = 0;
103        }
104
105        for block in 0..n_blocks {
106            let q = query[block];
107            let base = block * n_vec;
108
109            for vec_id in 0..n_vec {
110                let v = bps[base + vec_id];
111                let diff = if q > v { q - v } else { v - q };
112                out[vec_id] += diff as u32;
113            }
114        }
115    }
116}
117
118// ============================================================================
119// int8 Dot Product Dispatcher
120// ============================================================================
121
122/// int8 dot product dispatcher - uses pure Rust SIMD implementations.
123pub struct DotI8Dispatcher;
124
125impl DotI8Dispatcher {
126    /// Compute single dot product.
127    pub fn dot(a: &[i8], b: &[i8]) -> i32 {
128        crate::simd::dot_i8::dot_i8(a, b)
129    }
130
131    /// Compute int8 dot products for candidate reranking.
132    pub fn compute(
133        query: &[i8],
134        vectors: &[i8],
135        cand_ids: &[u32],
136        dim: usize,
137        out_scores: &mut [i32],
138    ) {
139        crate::simd::dot_i8::dot_i8_indexed(query, vectors, cand_ids, dim, out_scores);
140    }
141
142    /// Compute with dequantization for contiguous vectors.
143    pub fn compute_batch_contiguous(
144        query: &[i8],
145        vectors: &[i8],
146        scales: &[f32],
147        dim: usize,
148        out_scores: &mut [f32],
149    ) {
150        crate::simd::dot_i8::dot_i8_batch(query, vectors, scales, dim, out_scores);
151    }
152
153    /// Legacy interface - compute with dequantization for indexed access.
154    pub fn compute_batch(
155        query: &[i8],
156        vectors: &[i8],
157        cand_ids: &[u32],
158        dim: usize,
159        query_scale: f32,
160        vec_scales: &[f32],
161        out_scores: &mut [f32],
162    ) {
163        let n_cand = cand_ids.len();
164        assert!(query.len() >= dim);
165        assert!(out_scores.len() >= n_cand);
166
167        // Compute int32 scores first
168        let mut int_scores = vec![0i32; n_cand];
169        Self::compute(query, vectors, cand_ids, dim, &mut int_scores);
170
171        // Dequantize
172        let denom = 127.0 * 127.0;
173        for (i, &cand_id) in cand_ids.iter().enumerate() {
174            let scale = query_scale * vec_scales[cand_id as usize] / denom;
175            out_scores[i] = int_scores[i] as f32 * scale;
176        }
177    }
178
179    /// Rust fallback for single dot (for testing).
180    #[allow(dead_code)]
181    pub(crate) fn dot_fallback(a: &[i8], b: &[i8]) -> i32 {
182        a.iter()
183            .zip(b.iter())
184            .map(|(&x, &y)| (x as i32) * (y as i32))
185            .sum()
186    }
187
188    /// Rust fallback implementation (for testing).
189    #[allow(dead_code)]
190    pub(crate) fn compute_fallback(
191        query: &[i8],
192        vectors: &[i8],
193        cand_ids: &[u32],
194        dim: usize,
195        out_scores: &mut [i32],
196    ) {
197        for (i, &cand_id) in cand_ids.iter().enumerate() {
198            let offset = cand_id as usize * dim;
199            let vec = &vectors[offset..offset + dim];
200            out_scores[i] = Self::dot_fallback(&query[..dim], vec);
201        }
202    }
203
204    /// Rust fallback for batch contiguous (for testing).
205    #[allow(dead_code)]
206    pub(crate) fn compute_batch_fallback(
207        query: &[i8],
208        vectors: &[i8],
209        scales: &[f32],
210        dim: usize,
211        out_scores: &mut [f32],
212    ) {
213        for (i, &scale) in scales.iter().enumerate() {
214            let offset = i * dim;
215            let vec = &vectors[offset..offset + dim];
216            let int_score = Self::dot_fallback(&query[..dim], vec);
217            out_scores[i] = int_score as f32 * scale;
218        }
219    }
220}
221
222// ============================================================================
223// Visibility Check Dispatcher
224// ============================================================================
225
226/// SIMD-accelerated batch visibility checking for MVCC snapshots.
227///
228/// Checks which rows are visible to a given snapshot timestamp.
229/// A row is visible if: commit_ts != 0 && commit_ts < snapshot_ts
230/// Or if the row belongs to the current transaction (txn_id match).
231pub struct VisibilityDispatcher;
232
233impl VisibilityDispatcher {
234    /// Check visibility for a batch of rows based on commit timestamps.
235    ///
236    /// # Arguments
237    /// * `commit_timestamps` - Array of commit timestamps (0 = uncommitted)
238    /// * `snapshot_ts` - The snapshot timestamp for visibility check
239    /// * `visible_mask` - Output: 1 if visible, 0 if not visible
240    ///
241    /// # Panics
242    /// Panics if visible_mask length doesn't match commit_timestamps length.
243    pub fn check_batch(commit_timestamps: &[u64], snapshot_ts: u64, visible_mask: &mut [u8]) {
244        crate::simd::visibility::visibility_check(commit_timestamps, snapshot_ts, visible_mask);
245    }
246
247    /// Check visibility with transaction ID awareness (for self-visibility).
248    ///
249    /// A row is visible if:
250    /// - (commit_ts != 0 && commit_ts < snapshot_ts), OR
251    /// - txn_id == current_txn_id (self-visibility)
252    ///
253    /// # Arguments
254    /// * `commit_timestamps` - Array of commit timestamps (0 = uncommitted)
255    /// * `txn_ids` - Array of transaction IDs that wrote each row
256    /// * `snapshot_ts` - The snapshot timestamp for visibility check
257    /// * `current_txn_id` - The current transaction's ID
258    /// * `visible_mask` - Output: 1 if visible, 0 if not visible
259    pub fn check_batch_with_txn(
260        commit_timestamps: &[u64],
261        txn_ids: &[u64],
262        snapshot_ts: u64,
263        current_txn_id: u64,
264        visible_mask: &mut [u8],
265    ) {
266        crate::simd::visibility::visibility_check_with_txn(
267            commit_timestamps,
268            txn_ids,
269            snapshot_ts,
270            current_txn_id,
271            visible_mask,
272        );
273    }
274
275    /// Rust fallback implementation for batch visibility check (for testing).
276    #[allow(dead_code)]
277    pub(crate) fn check_batch_fallback(
278        commit_timestamps: &[u64],
279        snapshot_ts: u64,
280        visible_mask: &mut [u8],
281    ) {
282        for (i, &commit_ts) in commit_timestamps.iter().enumerate() {
283            visible_mask[i] = if commit_ts != 0 && commit_ts < snapshot_ts {
284                1
285            } else {
286                0
287            };
288        }
289    }
290
291    /// Rust fallback implementation for batch visibility check with txn ID (for testing).
292    #[allow(dead_code)]
293    pub(crate) fn check_batch_with_txn_fallback(
294        commit_timestamps: &[u64],
295        txn_ids: &[u64],
296        snapshot_ts: u64,
297        current_txn_id: u64,
298        visible_mask: &mut [u8],
299    ) {
300        for i in 0..commit_timestamps.len() {
301            let commit_ts = commit_timestamps[i];
302            let txn_id = txn_ids[i];
303            let visible = (commit_ts != 0 && commit_ts < snapshot_ts) || txn_id == current_txn_id;
304            visible_mask[i] = if visible { 1 } else { 0 };
305        }
306    }
307}
308
309// ============================================================================
310// Utility Functions
311// ============================================================================
312
313/// Check if SIMD kernels are available (runtime detection).
314pub fn simd_available() -> bool {
315    cpu_features().has_simd()
316}
317
318/// Get dispatch info for debugging (runtime detection).
319pub fn dispatch_info() -> String {
320    crate::simd::dispatch::dispatch_info()
321}
322
#[cfg(test)]
mod tests {
    use super::*;

    /// The scalar `u16` fallback must produce the exact L1 distances.
    #[test]
    fn test_bps_scan_fallback() {
        let n_vec = 100;
        let n_blocks = 4;
        let proj = 1;

        // Every block stores the vector index as its byte value.
        let mut bps = vec![0u8; n_blocks * proj * n_vec];
        for i in 0..n_vec {
            for b in 0..n_blocks {
                bps[b * n_vec + i] = (i % 256) as u8;
            }
        }

        let query = vec![128u8; n_blocks * proj];
        let mut out = vec![0u16; n_vec];

        BpsScanDispatcher::scan_fallback(&bps, n_vec, n_blocks, proj, &query, &mut out);

        // Each of the `n_blocks * proj` slots contributes |128 - i|; the
        // largest total here is 4 * 128 = 512, so nothing saturates.
        for (i, &d) in out.iter().enumerate() {
            let per_slot = u16::from(128u8.abs_diff((i % 256) as u8));
            let expected = per_slot * (n_blocks * proj) as u16;
            assert_eq!(d, expected, "unexpected distance for vector {}", i);
        }
    }

    /// The scalar `u32` fallback must produce the exact L1 distances.
    #[test]
    fn test_bps_scan_fallback_u32() {
        let n_vec = 100;
        let n_blocks = 4;

        // Every block stores the vector index as its byte value.
        let mut bps = vec![0u8; n_blocks * n_vec];
        for i in 0..n_vec {
            for b in 0..n_blocks {
                bps[b * n_vec + i] = (i % 256) as u8;
            }
        }

        let query = vec![128u8; n_blocks];
        let mut out = vec![0u32; n_vec];

        BpsScanDispatcher::scan_fallback_u32(&bps, n_vec, n_blocks, &query, &mut out);

        // Recompute the expected sum block by block.
        for (i, &d) in out.iter().enumerate() {
            let expected: u32 = (0..n_blocks)
                .map(|_b| {
                    let v = (i % 256) as u8;
                    let q = 128u8;
                    (if q > v { q - v } else { v - q }) as u32
                })
                .sum();
            assert_eq!(d, expected);
        }
    }

    /// Indexed fallback dot products against hand-computed values.
    #[test]
    fn test_dot_i8_fallback() {
        let dim = 64;
        let n_vec = 10;

        let query: Vec<i8> = (0..dim).map(|i| (i % 128) as i8).collect();
        let vectors: Vec<i8> = (0..n_vec * dim)
            .map(|i| ((i / dim) as i8).wrapping_mul(2))
            .collect();
        let cand_ids: Vec<u32> = (0..n_vec as u32).collect();
        let mut out = vec![0i32; n_vec];

        DotI8Dispatcher::compute_fallback(&query, &vectors, &cand_ids, dim, &mut out);

        // Scores should vary
        assert!(out.iter().any(|&s| s != out[0]));

        // Vector k is constant 2k and the query sums to 0 + 1 + ... + 63 = 2016,
        // so score k is 2016 * 2k = 4032k.
        for (k, &score) in out.iter().enumerate() {
            assert_eq!(score, 4032 * k as i32, "unexpected score for vector {}", k);
        }
    }

    /// Single fallback dot product against a hand-computed value.
    #[test]
    fn test_dot_single() {
        let a: Vec<i8> = vec![1, 2, 3, 4, 5];
        let b: Vec<i8> = vec![1, 2, 3, 4, 5];
        let result = DotI8Dispatcher::dot_fallback(&a, &b);
        assert_eq!(result, 1 + 4 + 9 + 16 + 25);
    }

    /// The dispatch description must not be empty.
    #[test]
    fn test_dispatch_info() {
        let info = dispatch_info();
        assert!(!info.is_empty());
        println!("Dispatch: {}", info);
    }

    /// Cross-validate SIMD dispatch vs fallback for bit-exact equivalence.
    #[test]
    fn test_simd_dispatch_cross_validation() {
        // Test BPS scan equivalence
        let n_vec = 256;
        let n_blocks = 8;

        // Generate deterministic test data
        let bps: Vec<u8> = (0..(n_blocks * n_vec))
            .map(|i| ((i * 17 + 13) % 256) as u8)
            .collect();
        let query: Vec<u8> = (0..n_blocks).map(|i| (i * 31 + 7) as u8).collect();

        // Reference: fallback implementation
        let mut ref_distances = vec![0u16; n_vec];
        BpsScanDispatcher::scan_fallback(&bps, n_vec, n_blocks, 1, &query, &mut ref_distances);

        // Dispatch: uses SIMD if available
        let mut dispatch_distances = vec![0u16; n_vec];
        BpsScanDispatcher::scan(&bps, n_vec, n_blocks, 1, &query, &mut dispatch_distances);

        // Verify bit-exact match
        for i in 0..n_vec {
            assert_eq!(
                ref_distances[i], dispatch_distances[i],
                "BPS scan mismatch at vector {}: fallback={}, dispatch={}",
                i, ref_distances[i], dispatch_distances[i]
            );
        }

        // Test int8 dot product equivalence
        let dim = 128;
        let a: Vec<i8> = (0..dim).map(|i| ((i * 3 - 64) % 128) as i8).collect();
        let b: Vec<i8> = (0..dim).map(|i| ((i * 7 + 32) % 128) as i8).collect();

        let ref_dot = DotI8Dispatcher::dot_fallback(&a, &b);
        let dispatch_dot = DotI8Dispatcher::dot(&a, &b);

        assert_eq!(
            ref_dot, dispatch_dot,
            "int8 dot product mismatch: fallback={}, dispatch={}",
            ref_dot, dispatch_dot
        );
    }

    /// Test CPU feature detection
    #[test]
    fn test_cpu_features_detection() {
        let features = cpu_features();
        let level = simd_level();

        println!("CPU Features: {:?}", features);
        println!("SIMD Level: {:?}", level);
        println!("Dispatch Info: {}", dispatch_info());

        // Most x86_64 CPUs offer SSE4.1 or better, but this test does not
        // rely on that: it only requires at least the scalar baseline.
        #[cfg(target_arch = "x86_64")]
        {
            assert!(level >= SimdLevel::Scalar);
        }

        // On aarch64, we always have NEON
        #[cfg(target_arch = "aarch64")]
        {
            assert!(features.has_neon);
            assert!(level >= SimdLevel::Neon);
        }
    }

    /// Test the dispatched visibility check on a small hand-checked batch.
    #[test]
    fn test_visibility_check_basic() {
        let commit_timestamps = vec![10, 0, 5, 15, 20, 8];
        let snapshot_ts = 12;
        let mut visible_mask = vec![0u8; 6];

        VisibilityDispatcher::check_batch(&commit_timestamps, snapshot_ts, &mut visible_mask);

        // Expected: [1, 0, 1, 0, 0, 1]
        assert_eq!(visible_mask, vec![1, 0, 1, 0, 0, 1]);
    }

    /// Test the dispatched visibility check with transaction ID.
    #[test]
    fn test_visibility_check_with_txn() {
        let commit_timestamps = vec![10, 0, 5, 0, 20, 8];
        let txn_ids = vec![1, 2, 3, 99, 5, 6];
        let snapshot_ts = 12;
        let current_txn_id = 99;
        let mut visible_mask = vec![0u8; 6];

        VisibilityDispatcher::check_batch_with_txn(
            &commit_timestamps,
            &txn_ids,
            snapshot_ts,
            current_txn_id,
            &mut visible_mask,
        );

        // Expected: [1, 0, 1, 1, 0, 1] (row 3 is uncommitted but self-written)
        assert_eq!(visible_mask, vec![1, 0, 1, 1, 0, 1]);
    }

    /// Test visibility dispatcher SIMD vs fallback equivalence
    #[test]
    fn test_visibility_simd_equivalence() {
        let n_rows = 1024;

        // Generate test data
        let commit_timestamps: Vec<u64> = (0..n_rows)
            .map(|i| if i % 5 == 0 { 0 } else { (i * 7 % 100) as u64 })
            .collect();
        let txn_ids: Vec<u64> = (0..n_rows).map(|i| (i % 10) as u64).collect();
        let snapshot_ts = 50;
        let current_txn_id = 7;

        // Test basic visibility
        let mut ref_mask = vec![0u8; n_rows];
        let mut dispatch_mask = vec![0u8; n_rows];

        VisibilityDispatcher::check_batch_fallback(&commit_timestamps, snapshot_ts, &mut ref_mask);
        VisibilityDispatcher::check_batch(&commit_timestamps, snapshot_ts, &mut dispatch_mask);

        for i in 0..n_rows {
            assert_eq!(
                ref_mask[i], dispatch_mask[i],
                "Visibility mismatch at row {}: fallback={}, dispatch={}",
                i, ref_mask[i], dispatch_mask[i]
            );
        }

        // Test with txn ID
        let mut ref_mask_txn = vec![0u8; n_rows];
        let mut dispatch_mask_txn = vec![0u8; n_rows];

        VisibilityDispatcher::check_batch_with_txn_fallback(
            &commit_timestamps,
            &txn_ids,
            snapshot_ts,
            current_txn_id,
            &mut ref_mask_txn,
        );
        VisibilityDispatcher::check_batch_with_txn(
            &commit_timestamps,
            &txn_ids,
            snapshot_ts,
            current_txn_id,
            &mut dispatch_mask_txn,
        );

        for i in 0..n_rows {
            assert_eq!(
                ref_mask_txn[i], dispatch_mask_txn[i],
                "Visibility+txn mismatch at row {}: fallback={}, dispatch={}",
                i, ref_mask_txn[i], dispatch_mask_txn[i]
            );
        }
    }
}