velesdb_core/
simd_dispatch.rs

1//! Zero-overhead SIMD function dispatch using `OnceLock`.
2//!
3//! This module provides compile-time-like dispatch for SIMD functions
4//! by detecting CPU features once at startup and caching function pointers.
5//!
6//! # Performance
7//!
8//! - **Zero branch overhead**: Function pointer is resolved once, called directly thereafter
9//! - **No per-call checks**: Eliminates `is_x86_feature_detected!` in hot loops
10//! - **Inlinable**: Function pointers can be inlined by LLVM in some cases
11//!
12//! # EPIC-C.2: TS-SIMD-002
13
14use std::sync::OnceLock;
15
/// Type alias for distance function pointers.
///
/// All `f32` distance kernels in this module (dot product, euclidean,
/// cosine) share this shape, so one dispatch table type covers them.
type DistanceFn = fn(&[f32], &[f32]) -> f32;

/// Type alias for binary distance function pointers (returns u32).
///
/// Used for Hamming distance, which counts differing positions.
type BinaryDistanceFn = fn(&[f32], &[f32]) -> u32;
21
22// =============================================================================
23// Static dispatch tables - initialized once on first use
24// =============================================================================
25
/// Dispatched dot product function; lazily initialized on first call to
/// `dot_product_dispatched` and never changed afterwards.
static DOT_PRODUCT_FN: OnceLock<DistanceFn> = OnceLock::new();

/// Dispatched euclidean distance function.
static EUCLIDEAN_FN: OnceLock<DistanceFn> = OnceLock::new();

/// Dispatched cosine similarity function.
static COSINE_FN: OnceLock<DistanceFn> = OnceLock::new();

/// Dispatched cosine similarity for normalized vectors (resolves to a
/// dot-product kernel, since unit norms make cosine == dot product).
static COSINE_NORMALIZED_FN: OnceLock<DistanceFn> = OnceLock::new();

/// Dispatched Hamming distance function.
static HAMMING_FN: OnceLock<BinaryDistanceFn> = OnceLock::new();
40
41// =============================================================================
42// Feature detection and dispatch selection
43// =============================================================================
44
45/// Selects the best dot product implementation for the current CPU.
46fn select_dot_product() -> DistanceFn {
47    #[cfg(target_arch = "x86_64")]
48    {
49        if is_x86_feature_detected!("avx512f") {
50            return dot_product_avx512;
51        }
52        if is_x86_feature_detected!("avx2") {
53            return dot_product_avx2;
54        }
55    }
56    dot_product_scalar
57}
58
59/// Selects the best euclidean distance implementation for the current CPU.
60fn select_euclidean() -> DistanceFn {
61    #[cfg(target_arch = "x86_64")]
62    {
63        if is_x86_feature_detected!("avx512f") {
64            return euclidean_avx512;
65        }
66        if is_x86_feature_detected!("avx2") {
67            return euclidean_avx2;
68        }
69    }
70    euclidean_scalar
71}
72
73/// Selects the best cosine similarity implementation for the current CPU.
74fn select_cosine() -> DistanceFn {
75    #[cfg(target_arch = "x86_64")]
76    {
77        if is_x86_feature_detected!("avx512f") {
78            return cosine_avx512;
79        }
80        if is_x86_feature_detected!("avx2") {
81            return cosine_avx2;
82        }
83    }
84    cosine_scalar
85}
86
87/// Selects the best cosine similarity (normalized) implementation.
88fn select_cosine_normalized() -> DistanceFn {
89    #[cfg(target_arch = "x86_64")]
90    {
91        if is_x86_feature_detected!("avx512f") {
92            return cosine_normalized_avx512;
93        }
94        if is_x86_feature_detected!("avx2") {
95            return cosine_normalized_avx2;
96        }
97    }
98    cosine_normalized_scalar
99}
100
101/// Selects the best Hamming distance implementation.
102fn select_hamming() -> BinaryDistanceFn {
103    #[cfg(target_arch = "x86_64")]
104    {
105        if is_x86_feature_detected!("avx512vpopcntdq") {
106            return hamming_avx512_popcnt;
107        }
108        if is_x86_feature_detected!("popcnt") {
109            return hamming_popcnt;
110        }
111    }
112    hamming_scalar
113}
114
115// =============================================================================
116// Public dispatch API
117// =============================================================================
118
119/// Computes dot product using the best available SIMD implementation.
120///
121/// The implementation is selected once on first call and cached.
122///
123/// # Panics
124///
125/// Panics if vectors have different lengths.
126#[inline]
127#[must_use]
128pub fn dot_product_dispatched(a: &[f32], b: &[f32]) -> f32 {
129    let f = DOT_PRODUCT_FN.get_or_init(select_dot_product);
130    f(a, b)
131}
132
/// Computes euclidean distance using the best available SIMD implementation.
///
/// The implementation is selected once on first call and cached.
///
/// # Panics
///
/// Panics if vectors have different lengths (asserted by the scalar path;
/// the SIMD paths delegate to other modules and are assumed to enforce the
/// same contract — confirm in `simd_explicit` / `simd_avx512`).
#[inline]
#[must_use]
pub fn euclidean_dispatched(a: &[f32], b: &[f32]) -> f32 {
    let f = EUCLIDEAN_FN.get_or_init(select_euclidean);
    f(a, b)
}
140
/// Computes cosine similarity using the best available SIMD implementation.
///
/// The implementation is selected once on first call and cached.
///
/// # Panics
///
/// Panics if vectors have different lengths (asserted by the scalar path;
/// SIMD paths are assumed to enforce the same contract — confirm in
/// `simd_explicit` / `simd_avx512`).
#[inline]
#[must_use]
pub fn cosine_dispatched(a: &[f32], b: &[f32]) -> f32 {
    let f = COSINE_FN.get_or_init(select_cosine);
    f(a, b)
}
148
/// Computes cosine similarity for pre-normalized vectors.
///
/// For unit-norm inputs cosine reduces to a dot product, so this dispatches
/// to a dot-product kernel. The result is only a valid cosine similarity if
/// both inputs are actually normalized — the caller's responsibility.
///
/// # Panics
///
/// Panics if vectors have different lengths (asserted by the scalar path).
#[inline]
#[must_use]
pub fn cosine_normalized_dispatched(a: &[f32], b: &[f32]) -> f32 {
    let f = COSINE_NORMALIZED_FN.get_or_init(select_cosine_normalized);
    f(a, b)
}
156
/// Computes Hamming distance using the best available implementation.
///
/// Inputs are binary vectors encoded as `f32` (the scalar path thresholds
/// each element at 0.5). The implementation is selected once and cached.
///
/// # Panics
///
/// Panics if vectors have different lengths (asserted by the scalar path).
#[inline]
#[must_use]
pub fn hamming_dispatched(a: &[f32], b: &[f32]) -> u32 {
    let f = HAMMING_FN.get_or_init(select_hamming);
    f(a, b)
}
164
/// Returns information about which SIMD features are available.
///
/// Note: this re-runs feature detection on every call (it is not cached
/// like the dispatch tables above); callers that query it repeatedly
/// should hold onto the returned struct.
#[must_use]
pub fn simd_features_info() -> SimdFeatures {
    SimdFeatures::detect()
}
170
/// Information about available SIMD features.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[allow(clippy::struct_excessive_bools)] // deliberate: one flag per CPU feature
pub struct SimdFeatures {
    /// AVX-512 foundation instructions available.
    pub avx512f: bool,
    /// AVX-512 VPOPCNTDQ (population count) available.
    pub avx512_popcnt: bool,
    /// AVX2 instructions available.
    pub avx2: bool,
    /// POPCNT instruction available.
    pub popcnt: bool,
}

impl SimdFeatures {
    /// Detects available SIMD features on the current CPU.
    ///
    /// On non-x86_64 targets every flag is reported as `false`.
    #[must_use]
    pub fn detect() -> Self {
        #[cfg(not(target_arch = "x86_64"))]
        {
            Self {
                avx512f: false,
                avx512_popcnt: false,
                avx2: false,
                popcnt: false,
            }
        }

        #[cfg(target_arch = "x86_64")]
        {
            Self {
                avx512f: is_x86_feature_detected!("avx512f"),
                avx512_popcnt: is_x86_feature_detected!("avx512vpopcntdq"),
                avx2: is_x86_feature_detected!("avx2"),
                popcnt: is_x86_feature_detected!("popcnt"),
            }
        }
    }

    /// Returns the best available instruction set name.
    ///
    /// Ranked AVX-512 > AVX2 > Scalar; POPCNT-only support does not get
    /// its own tier.
    #[must_use]
    pub const fn best_instruction_set(&self) -> &'static str {
        match (self.avx512f, self.avx2) {
            (true, _) => "AVX-512",
            (false, true) => "AVX2",
            (false, false) => "Scalar",
        }
    }
}
222
223// =============================================================================
224// Implementation functions - delegating to simd_avx512 and simd_explicit
225// =============================================================================
226
227// --- Dot Product implementations ---
228
/// Portable fallback dot product: sum of element-wise products.
fn dot_product_scalar(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Vector length mismatch");
    let mut acc = 0.0f32;
    for (x, y) in a.iter().zip(b.iter()) {
        acc += x * y;
    }
    acc
}
233
/// AVX2 dot product: thin adapter over the explicit-SIMD implementation.
// NOTE(review): assumed to match `dot_product_scalar` semantics including
// the length assert — confirm against `simd_explicit::dot_product_simd`.
#[cfg(target_arch = "x86_64")]
fn dot_product_avx2(a: &[f32], b: &[f32]) -> f32 {
    crate::simd_explicit::dot_product_simd(a, b)
}

/// AVX-512 dot product: delegates to the auto-dispatching AVX-512 kernel.
#[cfg(target_arch = "x86_64")]
fn dot_product_avx512(a: &[f32], b: &[f32]) -> f32 {
    crate::simd_avx512::dot_product_auto(a, b)
}
243
244// --- Euclidean implementations ---
245
/// Portable fallback euclidean (L2) distance: sqrt of summed squared diffs.
fn euclidean_scalar(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Vector length mismatch");
    let mut sum_sq = 0.0f32;
    for (x, y) in a.iter().zip(b.iter()) {
        let diff = x - y;
        sum_sq += diff * diff;
    }
    sum_sq.sqrt()
}
257
/// AVX2 euclidean distance: thin adapter over the explicit-SIMD implementation.
// NOTE(review): assumed to match `euclidean_scalar` semantics — confirm
// against `simd_explicit::euclidean_distance_simd`.
#[cfg(target_arch = "x86_64")]
fn euclidean_avx2(a: &[f32], b: &[f32]) -> f32 {
    crate::simd_explicit::euclidean_distance_simd(a, b)
}

/// AVX-512 euclidean distance: delegates to the auto-dispatching kernel.
#[cfg(target_arch = "x86_64")]
fn euclidean_avx512(a: &[f32], b: &[f32]) -> f32 {
    crate::simd_avx512::euclidean_auto(a, b)
}
267
268// --- Cosine implementations ---
269
/// Portable fallback cosine similarity.
///
/// Accumulates dot product and both squared norms in a single pass, then
/// normalizes. Returns 0.0 when either vector has zero norm (undefined
/// cosine), matching the dispatcher's contract.
fn cosine_scalar(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Vector length mismatch");

    // One fold, three accumulators: (dot, |a|^2, |b|^2).
    let (dot, norm_a, norm_b) = a.iter().zip(b.iter()).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, na, nb), (x, y)| (dot + x * y, na + x * x, nb + y * y),
    );

    let denom = (norm_a * norm_b).sqrt();
    if denom > 0.0 {
        dot / denom
    } else {
        0.0
    }
}
289
/// AVX2 cosine similarity: thin adapter over the explicit-SIMD implementation.
// NOTE(review): assumed to match `cosine_scalar` semantics (including the
// zero-norm -> 0.0 case) — confirm against `simd_explicit`.
#[cfg(target_arch = "x86_64")]
fn cosine_avx2(a: &[f32], b: &[f32]) -> f32 {
    crate::simd_explicit::cosine_similarity_simd(a, b)
}

/// AVX-512 cosine similarity: delegates to the auto-dispatching kernel.
#[cfg(target_arch = "x86_64")]
fn cosine_avx512(a: &[f32], b: &[f32]) -> f32 {
    crate::simd_avx512::cosine_similarity_auto(a, b)
}
299
300// --- Cosine Normalized implementations ---
301
/// Cosine similarity for unit-length vectors.
///
/// For normalized inputs both norms are 1, so cosine reduces to the plain
/// dot product — computed inline here.
fn cosine_normalized_scalar(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Vector length mismatch");
    a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}
306
/// AVX2 cosine for normalized vectors: a dot product, since unit norms make
/// cosine == dot product.
#[cfg(target_arch = "x86_64")]
fn cosine_normalized_avx2(a: &[f32], b: &[f32]) -> f32 {
    crate::simd_explicit::dot_product_simd(a, b)
}

/// AVX-512 cosine for normalized vectors: dot product via the AVX-512 kernel.
#[cfg(target_arch = "x86_64")]
fn cosine_normalized_avx512(a: &[f32], b: &[f32]) -> f32 {
    crate::simd_avx512::dot_product_auto(a, b)
}
316
317// --- Hamming implementations ---
318
/// Portable fallback Hamming distance over f32-encoded bits.
///
/// Each element is thresholded at 0.5 to recover a boolean; the result is
/// the number of positions where the two vectors disagree.
fn hamming_scalar(a: &[f32], b: &[f32]) -> u32 {
    assert_eq!(a.len(), b.len(), "Vector length mismatch");
    let mut differing = 0u32;
    for (x, y) in a.iter().zip(b.iter()) {
        if (*x > 0.5) != (*y > 0.5) {
            // wrapping_add mirrors the `count() as u32` truncation of the
            // (practically unreachable) > u32::MAX case.
            differing = differing.wrapping_add(1);
        }
    }
    differing
}
329
/// POPCNT-based Hamming distance adapter.
// NOTE(review): `hamming_distance_simd` returns a wider integer; the cast
// truncates above u32::MAX — assumed unreachable for real vector sizes.
#[cfg(target_arch = "x86_64")]
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
fn hamming_popcnt(a: &[f32], b: &[f32]) -> u32 {
    // Use existing implementation - safe cast as hamming distance is always positive integer
    crate::simd_explicit::hamming_distance_simd(a, b) as u32
}

/// AVX-512 VPOPCNTDQ Hamming distance (currently a stub over `hamming_popcnt`).
#[cfg(target_arch = "x86_64")]
fn hamming_avx512_popcnt(a: &[f32], b: &[f32]) -> u32 {
    // For now, delegate to regular popcnt
    // TODO: Implement true AVX-512 VPOPCNTDQ when available
    hamming_popcnt(a, b)
}
343
344// =============================================================================
345// Prefetch constants - EPIC-C.1
346// =============================================================================
347
/// Cache line size in bytes (standard for modern x86/ARM CPUs).
pub const CACHE_LINE_SIZE: usize = 64;

/// Calculates prefetch distance (in cache lines) for a given dimension at
/// compile time: `dimension * sizeof(f32) / CACHE_LINE_SIZE`.
#[inline]
#[must_use]
pub const fn prefetch_distance(dimension: usize) -> usize {
    (dimension * std::mem::size_of::<f32>()) / CACHE_LINE_SIZE
}

/// Prefetch distance for 768-dimensional vectors (3072 bytes = 48 lines).
pub const PREFETCH_DISTANCE_768D: usize = prefetch_distance(768);

/// Prefetch distance for 384-dimensional vectors (24 lines).
pub const PREFETCH_DISTANCE_384D: usize = prefetch_distance(384);

/// Prefetch distance for 1536-dimensional vectors (96 lines).
pub const PREFETCH_DISTANCE_1536D: usize = prefetch_distance(1536);
367
368// =============================================================================
369// TDD TESTS
370// =============================================================================
371
#[cfg(test)]
#[allow(
    clippy::cast_precision_loss,
    clippy::uninlined_format_args,
    clippy::float_cmp
)]
mod tests {
    use super::*;

    // -------------------------------------------------------------------------
    // Dispatch correctness tests
    // -------------------------------------------------------------------------

    #[test]
    fn test_dot_product_dispatched_correctness() {
        // Arrange
        let a = vec![1.0f32, 2.0, 3.0, 4.0];
        let b = vec![5.0f32, 6.0, 7.0, 8.0];

        // Act
        let result = dot_product_dispatched(&a, &b);

        // Assert - 1*5 + 2*6 + 3*7 + 4*8 = 5 + 12 + 21 + 32 = 70
        assert!((result - 70.0).abs() < 1e-5);
    }

    #[test]
    fn test_euclidean_dispatched_correctness() {
        // Arrange
        let a = vec![0.0f32, 0.0, 0.0];
        let b = vec![3.0f32, 4.0, 0.0];

        // Act
        let result = euclidean_dispatched(&a, &b);

        // Assert - sqrt(9 + 16) = 5
        assert!((result - 5.0).abs() < 1e-5);
    }

    #[test]
    fn test_cosine_dispatched_correctness() {
        // Arrange - same vector should have cosine = 1.0
        let a = vec![1.0f32, 2.0, 3.0];
        let b = vec![1.0f32, 2.0, 3.0];

        // Act
        let result = cosine_dispatched(&a, &b);

        // Assert
        assert!((result - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_cosine_dispatched_orthogonal() {
        // Arrange - orthogonal vectors should have cosine = 0
        let a = vec![1.0f32, 0.0, 0.0];
        let b = vec![0.0f32, 1.0, 0.0];

        // Act
        let result = cosine_dispatched(&a, &b);

        // Assert
        assert!(result.abs() < 1e-5);
    }

    #[test]
    fn test_cosine_normalized_dispatched() {
        // Arrange - pre-normalized vectors
        let a = vec![1.0f32, 0.0];
        let b = vec![0.707f32, 0.707]; // ~45 degrees

        // Act
        let result = cosine_normalized_dispatched(&a, &b);

        // Assert - cos(45°) ≈ 0.707
        assert!((result - 0.707).abs() < 0.01);
    }

    #[test]
    fn test_hamming_dispatched_correctness() {
        // Arrange - binary vectors encoded as f32
        let a = vec![1.0f32, 0.0, 1.0, 0.0]; // bits: 1010
        let b = vec![1.0f32, 1.0, 0.0, 0.0]; // bits: 1100

        // Act
        let result = hamming_dispatched(&a, &b);

        // Assert - differs in positions 1 and 2
        assert_eq!(result, 2);
    }

    // -------------------------------------------------------------------------
    // Large vector tests (768D like real embeddings)
    // -------------------------------------------------------------------------

    #[test]
    fn test_dot_product_dispatched_768d() {
        // Arrange
        let a: Vec<f32> = (0..768).map(|i| (i as f32) * 0.001).collect();
        let b: Vec<f32> = (0..768).map(|i| ((768 - i) as f32) * 0.001).collect();

        // Act
        let result = dot_product_dispatched(&a, &b);

        // Assert - just verify it doesn't panic and returns reasonable value
        assert!(result.is_finite());
        assert!(result > 0.0);
    }

    #[test]
    fn test_euclidean_dispatched_768d() {
        // Arrange
        let a: Vec<f32> = vec![0.0; 768];
        let b: Vec<f32> = vec![1.0; 768];

        // Act
        let result = euclidean_dispatched(&a, &b);

        // Assert - sqrt(768 * 1) ≈ 27.71
        assert!((result - 768.0_f32.sqrt()).abs() < 0.01);
    }

    #[test]
    fn test_cosine_dispatched_768d() {
        // Arrange
        let a: Vec<f32> = (0..768).map(|i| (i as f32).sin()).collect();
        let b = a.clone();

        // Act
        let result = cosine_dispatched(&a, &b);

        // Assert - same vector = 1.0
        assert!((result - 1.0).abs() < 1e-4);
    }

    // -------------------------------------------------------------------------
    // SIMD features detection tests
    // -------------------------------------------------------------------------

    #[test]
    fn test_simd_features_detect() {
        // Act
        let features = SimdFeatures::detect();

        // Assert - just verify it doesn't panic
        let _name = features.best_instruction_set();
        println!("SIMD features: {:?}", features);
        println!("Best instruction set: {}", features.best_instruction_set());
    }

    #[test]
    fn test_simd_features_info() {
        // Act
        let features = simd_features_info();

        // Assert - returns valid struct
        assert!(!features.best_instruction_set().is_empty());
    }

    // -------------------------------------------------------------------------
    // Prefetch constant tests
    // -------------------------------------------------------------------------

    #[test]
    fn test_prefetch_distance_768d() {
        // 768 * 4 bytes / 64 bytes = 48 cache lines
        assert_eq!(PREFETCH_DISTANCE_768D, 48);
    }

    #[test]
    fn test_prefetch_distance_384d() {
        // 384 * 4 bytes / 64 bytes = 24 cache lines
        assert_eq!(PREFETCH_DISTANCE_384D, 24);
    }

    #[test]
    fn test_prefetch_distance_1536d() {
        // 1536 * 4 bytes / 64 bytes = 96 cache lines
        assert_eq!(PREFETCH_DISTANCE_1536D, 96);
    }

    #[test]
    fn test_prefetch_distance_function() {
        assert_eq!(prefetch_distance(768), 48);
        assert_eq!(prefetch_distance(384), 24);
        assert_eq!(prefetch_distance(128), 8);
    }

    // -------------------------------------------------------------------------
    // OnceLock initialization tests
    // -------------------------------------------------------------------------

    #[test]
    fn test_dispatch_initialized_once() {
        // Multiple calls should use cached function pointer
        let a = vec![1.0f32; 100];
        let b = vec![2.0f32; 100];

        // First call initializes
        let r1 = dot_product_dispatched(&a, &b);

        // Second call uses cached pointer
        let r2 = dot_product_dispatched(&a, &b);

        // Results should be identical
        assert_eq!(r1, r2);
    }

    #[test]
    fn test_dispatch_thread_safe() {
        use std::sync::Arc;
        use std::thread;

        // Arrange
        let a = Arc::new(vec![1.0f32; 768]);
        let b = Arc::new(vec![2.0f32; 768]);

        // Act - multiple threads calling dispatched functions
        let handles: Vec<_> = (0..4)
            .map(|_| {
                let a = Arc::clone(&a);
                let b = Arc::clone(&b);
                thread::spawn(move || {
                    for _ in 0..100 {
                        let _ = dot_product_dispatched(&a, &b);
                        let _ = cosine_dispatched(&a, &b);
                        let _ = euclidean_dispatched(&a, &b);
                    }
                })
            })
            .collect();

        // Assert - no panics
        for h in handles {
            h.join().expect("Thread should not panic");
        }
    }

    // -------------------------------------------------------------------------
    // Edge case tests
    // -------------------------------------------------------------------------

    #[test]
    // FIX: the previous `expected = "dimensions must match"` never matched the
    // scalar path's panic message ("Vector length mismatch"), and the message
    // of the SIMD paths depends on which implementation was dispatched on the
    // test machine. Use a bare `should_panic` since only the panic itself —
    // not its wording — is part of this module's contract.
    #[should_panic]
    fn test_dot_product_dispatched_length_mismatch() {
        let a = vec![1.0f32, 2.0];
        let b = vec![1.0f32, 2.0, 3.0];
        let _ = dot_product_dispatched(&a, &b);
    }

    #[test]
    fn test_empty_vectors() {
        let a: Vec<f32> = vec![];
        let b: Vec<f32> = vec![];

        // Should not panic, returns 0
        assert_eq!(dot_product_dispatched(&a, &b), 0.0);
    }

    #[test]
    fn test_single_element() {
        let a = vec![3.0f32];
        let b = vec![4.0f32];

        assert_eq!(dot_product_dispatched(&a, &b), 12.0);
    }
}