Skip to main content

velesdb_core/
simd_dispatch.rs

1#![allow(
2    clippy::cast_precision_loss,
3    clippy::cast_possible_truncation,
4    clippy::cast_sign_loss,
5    clippy::float_cmp
6)]
7//! Zero-overhead SIMD function dispatch.
8//!
9//! This module provides a thin wrapper around `simd_native` functions,
10//! offering a stable public API while `simd_native` handles the
11//! architecture-specific SIMD implementations internally.
12//!
13//! # EPIC-C.2: TS-SIMD-002
14
15// SAFETY: Numeric casts in SIMD dispatch are intentional:
16// - usize->u32 for Hamming distance: vector dimensions bounded by implementation
17// - Maximum dimension is 65536, result fits in u32
18
19// =============================================================================
20// Public dispatch API - Direct calls to simd_native
21// =============================================================================
22
23/// Compute dot product with automatic SIMD dispatch.
24#[inline]
25#[must_use]
26pub fn dot_product_dispatched(a: &[f32], b: &[f32]) -> f32 {
27    crate::simd_native::dot_product_native(a, b)
28}
29
30/// Compute Euclidean distance with automatic SIMD dispatch.
31#[inline]
32#[must_use]
33pub fn euclidean_dispatched(a: &[f32], b: &[f32]) -> f32 {
34    crate::simd_native::euclidean_native(a, b)
35}
36
37/// Compute cosine similarity with automatic SIMD dispatch.
38#[inline]
39#[must_use]
40pub fn cosine_dispatched(a: &[f32], b: &[f32]) -> f32 {
41    crate::simd_native::cosine_similarity_native(a, b)
42}
43
44/// Compute cosine similarity for pre-normalized vectors.
45#[inline]
46#[must_use]
47pub fn cosine_normalized_dispatched(a: &[f32], b: &[f32]) -> f32 {
48    crate::simd_native::cosine_normalized_native(a, b)
49}
50
51/// Compute Hamming distance with automatic SIMD dispatch.
52#[inline]
53#[must_use]
54pub fn hamming_dispatched(a: &[f32], b: &[f32]) -> u32 {
55    #[allow(clippy::cast_sign_loss)]
56    // SAFETY: hamming_distance_native returns count of differing bits (non-negative),
57    // and vector dimensions are bounded by u32::MAX, so result always fits in u32
58    {
59        crate::simd_native::hamming_distance_native(a, b) as u32
60    }
61}
62
/// Reports the SIMD capabilities of the CPU this process is running on.
///
/// Convenience free function; equivalent to [`SimdFeatures::detect`].
#[must_use]
pub fn simd_features_info() -> SimdFeatures {
    SimdFeatures::detect()
}

/// Snapshot of the SIMD-related CPU features relevant to this crate.
///
/// Each flag is independent; no flag is derived from another.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
// One bool per CPU feature mirrors the hardware flags directly, so the
// "too many bools" lint does not apply here.
#[allow(clippy::struct_excessive_bools)]
pub struct SimdFeatures {
    /// `true` when the AVX-512 Foundation instructions are usable.
    pub avx512f: bool,
    /// `true` when AVX-512 VPOPCNTDQ (vector population count) is usable.
    pub avx512_popcnt: bool,
    /// `true` when AVX2 is usable.
    pub avx2: bool,
    /// `true` when the scalar POPCNT instruction is usable.
    pub popcnt: bool,
}

impl SimdFeatures {
    /// Probes the current CPU for the features tracked by this struct.
    ///
    /// Runtime detection is only performed on `x86_64`; on every other target
    /// architecture all flags are reported as `false`.
    #[must_use]
    pub fn detect() -> Self {
        Self::probe()
    }

    /// Runtime probe used on `x86_64`, backed by `is_x86_feature_detected!`.
    #[cfg(target_arch = "x86_64")]
    fn probe() -> Self {
        let avx512f = is_x86_feature_detected!("avx512f");
        let avx512_popcnt = is_x86_feature_detected!("avx512vpopcntdq");
        let avx2 = is_x86_feature_detected!("avx2");
        let popcnt = is_x86_feature_detected!("popcnt");

        Self {
            avx512f,
            avx512_popcnt,
            avx2,
            popcnt,
        }
    }

    /// Fallback for targets other than `x86_64`: nothing is reported.
    #[cfg(not(target_arch = "x86_64"))]
    fn probe() -> Self {
        Self {
            avx512f: false,
            avx512_popcnt: false,
            avx2: false,
            popcnt: false,
        }
    }

    /// Names the widest instruction set flagged as available.
    ///
    /// Returns `"AVX-512"` when `avx512f` is set, otherwise `"AVX2"` when
    /// `avx2` is set, otherwise `"Scalar"`. The two population-count flags do
    /// not influence the result.
    #[must_use]
    pub const fn best_instruction_set(&self) -> &'static str {
        match (self.avx512f, self.avx2) {
            (true, _) => "AVX-512",
            (false, true) => "AVX2",
            (false, false) => "Scalar",
        }
    }
}
120
// =============================================================================
// Test-only scalar references, followed by prefetch constants - EPIC-C.1
// =============================================================================
124
125// Scalar implementations for tests
/// Plain scalar dot product, compiled only for tests.
///
/// # Panics
///
/// Panics with "Vector length mismatch" when the slices differ in length.
#[cfg(test)]
fn dot_product_scalar(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Vector length mismatch");
    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum()
}
131
/// Plain scalar Euclidean (L2) distance, compiled only for tests.
///
/// # Panics
///
/// Panics with "Vector length mismatch" when the slices differ in length.
#[cfg(test)]
fn euclidean_scalar(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Vector length mismatch");
    let squared_distance: f32 = a
        .iter()
        .zip(b)
        .map(|(&lhs, &rhs)| {
            let diff = lhs - rhs;
            diff * diff
        })
        .sum();
    squared_distance.sqrt()
}
144
/// Plain scalar cosine similarity, compiled only for tests.
///
/// Returns `0.0` whenever the denominator `sqrt(|a|^2 * |b|^2)` is not
/// strictly positive, which covers zero-norm inputs, empty slices and a NaN
/// denominator.
///
/// # Panics
///
/// Panics with "Vector length mismatch" when the slices differ in length.
#[cfg(test)]
fn cosine_scalar(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Vector length mismatch");

    // Single pass accumulating the dot product and both squared norms.
    let (dot, sq_norm_a, sq_norm_b) = a.iter().zip(b).fold(
        (0.0_f32, 0.0_f32, 0.0_f32),
        |(dot, sq_a, sq_b), (&x, &y)| (dot + x * y, sq_a + x * x, sq_b + y * y),
    );

    let denom = (sq_norm_a * sq_norm_b).sqrt();
    // `>` is false for NaN as well, so a NaN denominator also yields 0.0.
    if denom > 0.0 {
        dot / denom
    } else {
        0.0
    }
}
165
/// Plain scalar Hamming distance over thresholded floats, compiled only for
/// tests.
///
/// Each component counts as a set bit when it is strictly greater than `0.5`;
/// the result is the number of positions where the two inputs disagree.
///
/// # Panics
///
/// Panics with "Vector length mismatch" when the slices differ in length.
#[cfg(test)]
// The count is bounded by the slice length; the narrowing cast is intentional.
#[allow(clippy::cast_possible_truncation)]
fn hamming_scalar(a: &[f32], b: &[f32]) -> u32 {
    assert_eq!(a.len(), b.len(), "Vector length mismatch");

    let is_set = |value: f32| value > 0.5;

    let mut differing = 0_usize;
    for (&x, &y) in a.iter().zip(b) {
        if is_set(x) != is_set(y) {
            differing += 1;
        }
    }
    differing as u32
}
177
/// Scalar cosine similarity for unit-length inputs, compiled only for tests.
///
/// Simply forwards to `dot_product_scalar`, so it panics under the same
/// condition (slices of different lengths). The inputs are not checked for
/// unit length.
#[cfg(test)]
fn cosine_normalized_scalar(a: &[f32], b: &[f32]) -> f32 {
    // With |a| = |b| = 1 the cosine denominator is 1, leaving only a . b.
    dot_product_scalar(a, b)
}
183
/// Assumed cache line size in bytes (a common value on x86 and many ARM CPUs).
pub const CACHE_LINE_SIZE: usize = 64;

/// Prefetch distance for 768-dimensional vectors (3072 bytes).
/// Evaluated at compile time: `768 * 4 / 64 = 48` cache lines.
pub const PREFETCH_DISTANCE_768D: usize = prefetch_distance(768);

/// Prefetch distance for 384-dimensional vectors (`384 * 4 / 64 = 24`).
pub const PREFETCH_DISTANCE_384D: usize = prefetch_distance(384);

/// Prefetch distance for 1536-dimensional vectors (`1536 * 4 / 64 = 96`).
pub const PREFETCH_DISTANCE_1536D: usize = prefetch_distance(1536);

/// Number of whole cache lines covered by an `f32` vector of `dimension`
/// elements.
///
/// Usable in `const` contexts. The division truncates, so a vector smaller
/// than one cache line (fewer than 16 elements) yields `0`.
#[inline]
#[must_use]
pub const fn prefetch_distance(dimension: usize) -> usize {
    let vector_bytes = dimension * std::mem::size_of::<f32>();
    vector_bytes / CACHE_LINE_SIZE
}
203
204// =============================================================================
205// Tests
206// =============================================================================
207
#[cfg(test)]
mod tests {
    //! Unit tests for the dispatch wrappers, CPU feature detection, the
    //! test-only scalar helpers and the prefetch constants.

    use super::*;

    /// Asserts that `actual` is strictly within `tolerance` of `expected`.
    ///
    /// On failure the message contains all three values, and thanks to
    /// `#[track_caller]` the reported location is the calling test.
    #[track_caller]
    fn assert_close(actual: f32, expected: f32, tolerance: f32) {
        assert!(
            (actual - expected).abs() < tolerance,
            "expected {expected} (tolerance {tolerance}), got {actual}"
        );
    }

    // =========================================================================
    // Dispatched Function Tests
    // =========================================================================

    #[test]
    fn test_dot_product_dispatched_basic() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];
        // 1*5 + 2*6 + 3*7 + 4*8 = 70
        assert_close(dot_product_dispatched(&a, &b), 70.0, 1e-5);
    }

    #[test]
    fn test_dot_product_dispatched_large() {
        let a: Vec<f32> = (0..768).map(|i| (i as f32 * 0.001).sin()).collect();
        let b: Vec<f32> = (0..768).map(|i| (i as f32 * 0.001).cos()).collect();
        let result = dot_product_dispatched(&a, &b);
        assert!(result.is_finite(), "dot product is not finite: {result}");
    }

    #[test]
    fn test_euclidean_dispatched_basic() {
        let a = vec![0.0, 0.0];
        let b = vec![3.0, 4.0];
        // Classic 3-4-5 triangle.
        assert_close(euclidean_dispatched(&a, &b), 5.0, 1e-5);
    }

    #[test]
    fn test_euclidean_dispatched_identical() {
        let a: Vec<f32> = vec![1.0; 64];
        assert_close(euclidean_dispatched(&a, &a), 0.0, 1e-6);
    }

    #[test]
    fn test_cosine_dispatched_identical() {
        let a: Vec<f32> = vec![1.0; 32];
        assert_close(cosine_dispatched(&a, &a), 1.0, 1e-5);
    }

    #[test]
    fn test_cosine_dispatched_orthogonal() {
        let mut a = vec![0.0_f32; 32];
        let mut b = vec![0.0_f32; 32];
        a[0] = 1.0;
        b[1] = 1.0;
        assert_close(cosine_dispatched(&a, &b), 0.0, 1e-5);
    }

    #[test]
    fn test_cosine_dispatched_opposite() {
        let a: Vec<f32> = vec![1.0; 16];
        let b: Vec<f32> = vec![-1.0; 16];
        assert_close(cosine_dispatched(&a, &b), -1.0, 1e-5);
    }

    #[test]
    fn test_cosine_normalized_dispatched() {
        // Pre-normalized unit vector: 32 equal components of 1/sqrt(32).
        let norm = (32.0_f32).sqrt();
        let a: Vec<f32> = vec![1.0 / norm; 32];
        assert_close(cosine_normalized_dispatched(&a, &a), 1.0, 1e-4);
    }

    #[test]
    fn test_hamming_dispatched_identical() {
        let a: Vec<f32> = vec![1.0; 32];
        assert_eq!(hamming_dispatched(&a, &a), 0);
    }

    #[test]
    fn test_hamming_dispatched_different() {
        let a: Vec<f32> = vec![1.0; 32]; // All above 0.5
        let b: Vec<f32> = vec![0.0; 32]; // All below 0.5
        assert_eq!(hamming_dispatched(&a, &b), 32);
    }

    #[test]
    fn test_hamming_dispatched_half() {
        let a = vec![1.0_f32; 32];
        let mut b = vec![1.0_f32; 32];
        // Make the first half differ.
        b[..16].fill(0.0);
        assert_eq!(hamming_dispatched(&a, &b), 16);
    }

    // =========================================================================
    // SimdFeatures Tests
    // =========================================================================

    #[test]
    fn test_simd_features_detect() {
        // Detection must not panic and every field must be readable.
        let SimdFeatures {
            avx512f,
            avx512_popcnt,
            avx2,
            popcnt,
        } = SimdFeatures::detect();
        let _ = (avx512f, avx512_popcnt, avx2, popcnt);
    }

    #[test]
    fn test_simd_features_info() {
        // The free function is documented as a shortcut for `detect()`.
        assert_eq!(simd_features_info(), SimdFeatures::detect());
    }

    #[test]
    fn test_simd_features_best_instruction_set() {
        let best = SimdFeatures::detect().best_instruction_set();
        assert!(
            matches!(best, "AVX-512" | "AVX2" | "Scalar"),
            "Unexpected instruction set: {best}"
        );
    }

    #[test]
    fn test_simd_features_debug() {
        let features = SimdFeatures::detect();
        let debug_str = format!("{features:?}");
        assert!(debug_str.contains("SimdFeatures"));
    }

    #[test]
    fn test_simd_features_clone() {
        let features = SimdFeatures::detect();
        // Call `clone()` explicitly: a plain `let cloned = features;` would
        // only exercise `Copy` and leave the derived `Clone` impl untested.
        #[allow(clippy::clone_on_copy)]
        let cloned = features.clone();
        assert_eq!(features, cloned);
    }

    // =========================================================================
    // Scalar Fallback Tests
    // =========================================================================

    #[test]
    fn test_dot_product_scalar() {
        // 1*4 + 2*5 + 3*6 = 32
        let result = dot_product_scalar(&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]);
        assert_close(result, 32.0, 1e-6);
    }

    #[test]
    fn test_euclidean_scalar() {
        // sqrt(1 + 4 + 4) = 3
        let result = euclidean_scalar(&[0.0, 0.0, 0.0], &[1.0, 2.0, 2.0]);
        assert_close(result, 3.0, 1e-6);
    }

    #[test]
    fn test_cosine_scalar_identical() {
        let a = [1.0_f32, 2.0, 3.0];
        assert_close(cosine_scalar(&a, &a), 1.0, 1e-6);
    }

    #[test]
    fn test_cosine_scalar_zero_norm() {
        // A zero vector has no direction; the helper defines the result as 0.
        let result = cosine_scalar(&[0.0, 0.0, 0.0], &[1.0, 2.0, 3.0]);
        assert_close(result, 0.0, 1e-6);
    }

    #[test]
    fn test_cosine_normalized_scalar() {
        // Orthogonal unit vectors.
        let result = cosine_normalized_scalar(&[1.0, 0.0], &[0.0, 1.0]);
        assert_close(result, 0.0, 1e-6);
    }

    #[test]
    fn test_hamming_scalar() {
        let a = [1.0_f32, 0.0, 1.0, 0.0];
        let b = [0.0_f32, 1.0, 1.0, 0.0];
        // Position 0: 1.0 > 0.5, 0.0 < 0.5 -> different
        // Position 1: 0.0 < 0.5, 1.0 > 0.5 -> different
        // Position 2: same
        // Position 3: same
        assert_eq!(hamming_scalar(&a, &b), 2);
    }

    // =========================================================================
    // Prefetch Distance Tests
    // =========================================================================

    #[test]
    fn test_prefetch_distance_384d() {
        let dist = prefetch_distance(384);
        assert_eq!(dist, PREFETCH_DISTANCE_384D);
        assert_eq!(dist, 24); // 384 * 4 / 64
    }

    #[test]
    fn test_prefetch_distance_768d() {
        let dist = prefetch_distance(768);
        assert_eq!(dist, PREFETCH_DISTANCE_768D);
        assert_eq!(dist, 48); // 768 * 4 / 64
    }

    #[test]
    fn test_prefetch_distance_1536d() {
        let dist = prefetch_distance(1536);
        assert_eq!(dist, PREFETCH_DISTANCE_1536D);
        assert_eq!(dist, 96); // 1536 * 4 / 64
    }

    #[test]
    fn test_cache_line_size() {
        assert_eq!(CACHE_LINE_SIZE, 64);
    }

    // =========================================================================
    // Edge Cases
    // =========================================================================

    #[test]
    #[should_panic(expected = "Vector length mismatch")]
    fn test_dot_product_scalar_length_mismatch() {
        let _ = dot_product_scalar(&[1.0, 2.0], &[1.0]);
    }

    #[test]
    #[should_panic(expected = "Vector length mismatch")]
    fn test_euclidean_scalar_length_mismatch() {
        let _ = euclidean_scalar(&[1.0, 2.0], &[1.0]);
    }

    #[test]
    #[should_panic(expected = "Vector length mismatch")]
    fn test_cosine_scalar_length_mismatch() {
        let _ = cosine_scalar(&[1.0, 2.0], &[1.0]);
    }

    #[test]
    #[should_panic(expected = "Vector length mismatch")]
    fn test_hamming_scalar_length_mismatch() {
        let _ = hamming_scalar(&[1.0, 2.0], &[1.0]);
    }

    #[test]
    fn test_empty_vectors() {
        let empty: [f32; 0] = [];
        assert_close(dot_product_scalar(&empty, &empty), 0.0, 1e-6);
    }
}