Skip to main content

ruvix_vecgraph/
simd_distance.rs

1//! SIMD-accelerated vector distance functions for kernel vector stores.
2//!
3//! This module provides optimized distance computations using:
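//! - NEON instructions on ARM64 (when the `neon` target feature is enabled)
//! - AVX2 instructions on x86_64 (when the `avx2` target feature is enabled),
//!   with SSE/SSE3 shuffles for the final horizontal reduction
//! - Scalar fallback for all other builds
//!
//! The kernel is selected at compile time via `cfg`, not by runtime CPU detection.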
7//!
8//! Performance targets (ADR-087):
9//! - Vector distance: SIMD acceleration with packed f32
//! - Process 4 (NEON) or 8 (AVX2) floats per vector instruction depending on platform
11//!
12//! # Example
13//!
14//! ```
15//! use ruvix_vecgraph::simd_distance::{cosine_similarity, euclidean_distance_squared};
16//!
17//! let a = [1.0f32, 2.0, 3.0, 4.0];
18//! let b = [4.0f32, 3.0, 2.0, 1.0];
19//!
20//! let cosine = cosine_similarity(&a, &b);
21//! let euclidean = euclidean_distance_squared(&a, &b);
22//! ```
23
24#[cfg(feature = "alloc")]
25extern crate alloc;
26
/// SIMD capabilities of the current build.
///
/// Each flag reflects the target features this crate was *compiled* with (evaluated
/// through `cfg!`), not a runtime CPU probe. A binary built without, for example,
/// `-C target-feature=+avx2` reports `avx2: false` even when it runs on an
/// AVX2-capable CPU. This matches the kernels the distance functions in this module
/// select, which are also chosen with `cfg`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SimdCapabilities {
    /// AVX2 is enabled for this build (x86_64).
    pub avx2: bool,
    /// AVX-512F is enabled for this build (x86_64).
    pub avx512: bool,
    /// NEON is enabled for this build (ARM64).
    pub neon: bool,
    /// Vector fused multiply-add is available: the `fma` target feature on x86,
    /// or NEON on ARM64 (where vector FMA is part of the NEON instruction set).
    pub fma: bool,
}

impl SimdCapabilities {
    /// Returns the SIMD capabilities enabled at compile time for this build.
    ///
    /// No CPU identification instructions are executed; every field is a
    /// compile-time constant for the target being built.
    #[must_use]
    pub fn detect() -> Self {
        Self {
            avx2: cfg!(all(target_arch = "x86_64", target_feature = "avx2")),
            avx512: cfg!(all(target_arch = "x86_64", target_feature = "avx512f")),
            neon: cfg!(all(target_arch = "aarch64", target_feature = "neon")),
            // The NEON kernels rely on `vfmaq_f32`: on ARM64, vector FMA comes with
            // NEON itself, so report it even though no separate `fma` feature is set.
            fma: cfg!(any(
                target_feature = "fma",
                all(target_arch = "aarch64", target_feature = "neon")
            )),
        }
    }

    /// Returns the number of `f32` lanes in the widest enabled vector register.
    ///
    /// AVX-512 gives 16, AVX2 gives 8, NEON gives 4, and otherwise the result is 1
    /// (scalar). This describes register width only; the kernels in this module use
    /// at most 8 lanes (AVX2).
    #[must_use]
    pub const fn lane_width(&self) -> usize {
        if self.avx512 {
            16 // 512 bits / 32 bits per f32
        } else if self.avx2 {
            8 // 256 bits / 32 bits per f32
        } else if self.neon {
            4 // 128 bits / 32 bits per f32
        } else {
            1 // Scalar fallback
        }
    }

    /// Returns `true` if AVX2, AVX-512 or NEON is enabled.
    ///
    /// `fma` on its own does not count as SIMD acceleration.
    #[must_use]
    pub const fn has_simd(&self) -> bool {
        self.avx2 || self.avx512 || self.neon
    }
}

impl Default for SimdCapabilities {
    /// Equivalent to [`SimdCapabilities::detect`].
    fn default() -> Self {
        Self::detect()
    }
}
78
79/// Computes cosine similarity between two vectors.
80///
81/// Returns a value in [-1, 1] where 1 means identical direction.
82///
83/// # Panics
84///
85/// Panics if vectors have different lengths.
86#[inline]
87pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
88    assert_eq!(a.len(), b.len(), "Vectors must have same length");
89
90    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
91    {
92        cosine_similarity_neon(a, b)
93    }
94
95    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
96    {
97        cosine_similarity_avx2(a, b)
98    }
99
100    #[cfg(not(any(
101        all(target_arch = "aarch64", target_feature = "neon"),
102        all(target_arch = "x86_64", target_feature = "avx2")
103    )))]
104    {
105        cosine_similarity_scalar(a, b)
106    }
107}
108
109/// Computes squared Euclidean distance between two vectors.
110///
111/// Returns sum of squared differences (no sqrt for performance).
112///
113/// # Panics
114///
115/// Panics if vectors have different lengths.
116#[inline]
117pub fn euclidean_distance_squared(a: &[f32], b: &[f32]) -> f32 {
118    assert_eq!(a.len(), b.len(), "Vectors must have same length");
119
120    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
121    {
122        euclidean_distance_squared_neon(a, b)
123    }
124
125    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
126    {
127        euclidean_distance_squared_avx2(a, b)
128    }
129
130    #[cfg(not(any(
131        all(target_arch = "aarch64", target_feature = "neon"),
132        all(target_arch = "x86_64", target_feature = "avx2")
133    )))]
134    {
135        euclidean_distance_squared_scalar(a, b)
136    }
137}
138
139/// Computes dot product between two vectors.
140///
141/// # Panics
142///
143/// Panics if vectors have different lengths.
144#[inline]
145pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
146    assert_eq!(a.len(), b.len(), "Vectors must have same length");
147
148    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
149    {
150        dot_product_neon(a, b)
151    }
152
153    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
154    {
155        dot_product_avx2(a, b)
156    }
157
158    #[cfg(not(any(
159        all(target_arch = "aarch64", target_feature = "neon"),
160        all(target_arch = "x86_64", target_feature = "avx2")
161    )))]
162    {
163        dot_product_scalar(a, b)
164    }
165}
166
167/// Computes L2 norm (magnitude) of a vector.
168#[inline]
169pub fn l2_norm(a: &[f32]) -> f32 {
170    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
171    {
172        l2_norm_neon(a)
173    }
174
175    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
176    {
177        l2_norm_avx2(a)
178    }
179
180    #[cfg(not(any(
181        all(target_arch = "aarch64", target_feature = "neon"),
182        all(target_arch = "x86_64", target_feature = "avx2")
183    )))]
184    {
185        l2_norm_scalar(a)
186    }
187}
188
189// ============================================================================
190// ARM64 NEON Implementations
191// ============================================================================
192
/// NEON kernel for [`cosine_similarity`].
///
/// Processes four lanes per step with three fused multiply-add accumulators (dot
/// product and both squared norms). It then reduces each accumulator horizontally and
/// finishes the last `< 4` elements in scalar code. Returns `0.0` when either vector
/// has zero magnitude.
///
/// Expects `a.len() == b.len()` (asserted by the public entry point). Loads are driven
/// by `chunks_exact(4)` on *both* slices, so a mismatched length can never cause an
/// out-of-bounds read from this safe function.
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
#[inline]
fn cosine_similarity_neon(a: &[f32], b: &[f32]) -> f32 {
    use core::arch::aarch64::*;

    debug_assert_eq!(a.len(), b.len(), "Vectors must have same length");

    let a_chunks = a.chunks_exact(4);
    let b_chunks = b.chunks_exact(4);
    let a_tail = a_chunks.remainder();
    let b_tail = b_chunks.remainder();

    // SAFETY: NEON is enabled for this build (enforced by the `cfg` on this function),
    // and each `vld1q_f32` reads exactly the four elements of a `chunks_exact(4)` chunk.
    let (dot, norm_a, norm_b) = unsafe {
        let mut dot_sum = vdupq_n_f32(0.0);
        let mut norm_a_sum = vdupq_n_f32(0.0);
        let mut norm_b_sum = vdupq_n_f32(0.0);

        for (ca, cb) in a_chunks.zip(b_chunks) {
            let va = vld1q_f32(ca.as_ptr());
            let vb = vld1q_f32(cb.as_ptr());

            // Fused multiply-add: dot += a * b
            dot_sum = vfmaq_f32(dot_sum, va, vb);
            // norm_a += a * a
            norm_a_sum = vfmaq_f32(norm_a_sum, va, va);
            // norm_b += b * b
            norm_b_sum = vfmaq_f32(norm_b_sum, vb, vb);
        }

        // Horizontal sums across the four lanes.
        (vaddvq_f32(dot_sum), vaddvq_f32(norm_a_sum), vaddvq_f32(norm_b_sum))
    };

    // Scalar tail for the elements that did not fill a whole vector.
    let mut dot_tail = 0.0f32;
    let mut norm_a_tail = 0.0f32;
    let mut norm_b_tail = 0.0f32;
    for (&ai, &bi) in a_tail.iter().zip(b_tail) {
        dot_tail += ai * bi;
        norm_a_tail += ai * ai;
        norm_b_tail += bi * bi;
    }

    let total_dot = dot + dot_tail;
    let total_norm_a = (norm_a + norm_a_tail).sqrt();
    let total_norm_b = (norm_b + norm_b_tail).sqrt();

    if total_norm_a == 0.0 || total_norm_b == 0.0 {
        0.0
    } else {
        total_dot / (total_norm_a * total_norm_b)
    }
}
248
/// NEON kernel for [`euclidean_distance_squared`].
///
/// Subtracts four lanes at a time and accumulates `diff * diff` with fused
/// multiply-add. It then reduces the accumulator horizontally and finishes the last
/// `< 4` elements in scalar code.
///
/// Expects `a.len() == b.len()` (asserted by the public entry point). Loads are driven
/// by `chunks_exact(4)` on *both* slices, so a mismatched length can never cause an
/// out-of-bounds read from this safe function.
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
#[inline]
fn euclidean_distance_squared_neon(a: &[f32], b: &[f32]) -> f32 {
    use core::arch::aarch64::*;

    debug_assert_eq!(a.len(), b.len(), "Vectors must have same length");

    let a_chunks = a.chunks_exact(4);
    let b_chunks = b.chunks_exact(4);
    let a_tail = a_chunks.remainder();
    let b_tail = b_chunks.remainder();

    // SAFETY: NEON is enabled for this build (enforced by the `cfg` on this function),
    // and each `vld1q_f32` reads exactly the four elements of a `chunks_exact(4)` chunk.
    let mut total = unsafe {
        let mut sum = vdupq_n_f32(0.0);
        for (ca, cb) in a_chunks.zip(b_chunks) {
            // diff = a - b; sum += diff * diff
            let diff = vsubq_f32(vld1q_f32(ca.as_ptr()), vld1q_f32(cb.as_ptr()));
            sum = vfmaq_f32(sum, diff, diff);
        }
        vaddvq_f32(sum)
    };

    // Scalar tail for the elements that did not fill a whole vector.
    for (&x, &y) in a_tail.iter().zip(b_tail) {
        let diff = x - y;
        total += diff * diff;
    }

    total
}
283
/// NEON kernel for [`dot_product`].
///
/// Accumulates four lanes per step with fused multiply-add, reduces the accumulator
/// horizontally, then adds the last `< 4` products in scalar code.
///
/// Expects `a.len() == b.len()` (asserted by the public entry point). Loads are driven
/// by `chunks_exact(4)` on *both* slices, so a mismatched length can never cause an
/// out-of-bounds read from this safe function.
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
#[inline]
fn dot_product_neon(a: &[f32], b: &[f32]) -> f32 {
    use core::arch::aarch64::*;

    debug_assert_eq!(a.len(), b.len(), "Vectors must have same length");

    let a_chunks = a.chunks_exact(4);
    let b_chunks = b.chunks_exact(4);
    let a_tail = a_chunks.remainder();
    let b_tail = b_chunks.remainder();

    // SAFETY: NEON is enabled for this build (enforced by the `cfg` on this function),
    // and each `vld1q_f32` reads exactly the four elements of a `chunks_exact(4)` chunk.
    let mut total = unsafe {
        let mut sum = vdupq_n_f32(0.0);
        for (ca, cb) in a_chunks.zip(b_chunks) {
            sum = vfmaq_f32(sum, vld1q_f32(ca.as_ptr()), vld1q_f32(cb.as_ptr()));
        }
        vaddvq_f32(sum)
    };

    // Scalar tail for the elements that did not fill a whole vector.
    for (&x, &y) in a_tail.iter().zip(b_tail) {
        total += x * y;
    }

    total
}
311
/// NEON kernel for [`l2_norm`].
///
/// Squares and accumulates four lanes per step with fused multiply-add. It reduces the
/// accumulator horizontally, folds in the remaining `< 4` squares sequentially, and
/// returns the square root of the total.
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
#[inline]
fn l2_norm_neon(a: &[f32]) -> f32 {
    use core::arch::aarch64::*;

    let chunks = a.chunks_exact(4);
    let tail = chunks.remainder();

    // SAFETY: NEON is enabled for this build (enforced by the `cfg` on this function),
    // and each `vld1q_f32` reads exactly the four elements of a `chunks_exact(4)` chunk.
    let vector_part = unsafe {
        let mut acc = vdupq_n_f32(0.0);
        for chunk in chunks {
            let v = vld1q_f32(chunk.as_ptr());
            acc = vfmaq_f32(acc, v, v);
        }
        vaddvq_f32(acc)
    };

    let sum_of_squares = tail.iter().fold(vector_part, |acc, &x| acc + x * x);
    sum_of_squares.sqrt()
}
338
339// ============================================================================
340// x86_64 AVX2 Implementations
341// ============================================================================
342
/// AVX2 kernel for [`cosine_similarity`].
///
/// Processes eight lanes per step with three multiply-add accumulators (dot product
/// and both squared norms). It then reduces each accumulator horizontally and finishes
/// the last `< 8` elements in scalar code. Returns `0.0` when either vector has zero
/// magnitude.
///
/// Expects `a.len() == b.len()` (asserted by the public entry point). Loads are driven
/// by `chunks_exact(8)` on *both* slices, so a mismatched length can never cause an
/// out-of-bounds read from this safe function.
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
#[inline]
fn cosine_similarity_avx2(a: &[f32], b: &[f32]) -> f32 {
    use core::arch::x86_64::*;

    debug_assert_eq!(a.len(), b.len(), "Vectors must have same length");

    let a_chunks = a.chunks_exact(8);
    let b_chunks = b.chunks_exact(8);
    let a_tail = a_chunks.remainder();
    let b_tail = b_chunks.remainder();

    // SAFETY: AVX2 is enabled for this build (enforced by the `cfg` on this function),
    // and each unaligned `_mm256_loadu_ps` reads exactly the eight elements of a
    // `chunks_exact(8)` chunk.
    let (dot, norm_a, norm_b) = unsafe {
        let mut dot_sum = _mm256_setzero_ps();
        let mut norm_a_sum = _mm256_setzero_ps();
        let mut norm_b_sum = _mm256_setzero_ps();

        for (ca, cb) in a_chunks.zip(b_chunks) {
            let va = _mm256_loadu_ps(ca.as_ptr());
            let vb = _mm256_loadu_ps(cb.as_ptr());

            dot_sum = mul_add_avx2(va, vb, dot_sum);
            norm_a_sum = mul_add_avx2(va, va, norm_a_sum);
            norm_b_sum = mul_add_avx2(vb, vb, norm_b_sum);
        }

        (
            horizontal_sum_avx2(dot_sum),
            horizontal_sum_avx2(norm_a_sum),
            horizontal_sum_avx2(norm_b_sum),
        )
    };

    // Scalar tail for the elements that did not fill a whole vector.
    let mut dot_tail = 0.0f32;
    let mut norm_a_tail = 0.0f32;
    let mut norm_b_tail = 0.0f32;
    for (&ai, &bi) in a_tail.iter().zip(b_tail) {
        dot_tail += ai * bi;
        norm_a_tail += ai * ai;
        norm_b_tail += bi * bi;
    }

    let total_dot = dot + dot_tail;
    let total_norm_a = (norm_a + norm_a_tail).sqrt();
    let total_norm_b = (norm_b + norm_b_tail).sqrt();

    if total_norm_a == 0.0 || total_norm_b == 0.0 {
        0.0
    } else {
        total_dot / (total_norm_a * total_norm_b)
    }
}

/// AVX2 kernel for [`euclidean_distance_squared`].
///
/// Subtracts eight lanes at a time and accumulates `diff * diff`. It then reduces the
/// accumulator horizontally and finishes the last `< 8` elements in scalar code.
///
/// Expects `a.len() == b.len()` (asserted by the public entry point). Loads are driven
/// by `chunks_exact(8)` on *both* slices, so a mismatched length can never cause an
/// out-of-bounds read from this safe function.
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
#[inline]
fn euclidean_distance_squared_avx2(a: &[f32], b: &[f32]) -> f32 {
    use core::arch::x86_64::*;

    debug_assert_eq!(a.len(), b.len(), "Vectors must have same length");

    let a_chunks = a.chunks_exact(8);
    let b_chunks = b.chunks_exact(8);
    let a_tail = a_chunks.remainder();
    let b_tail = b_chunks.remainder();

    // SAFETY: AVX2 is enabled for this build (enforced by the `cfg` on this function),
    // and each unaligned `_mm256_loadu_ps` reads exactly the eight elements of a
    // `chunks_exact(8)` chunk.
    let mut total = unsafe {
        let mut sum = _mm256_setzero_ps();
        for (ca, cb) in a_chunks.zip(b_chunks) {
            let diff = _mm256_sub_ps(_mm256_loadu_ps(ca.as_ptr()), _mm256_loadu_ps(cb.as_ptr()));
            sum = mul_add_avx2(diff, diff, sum);
        }
        horizontal_sum_avx2(sum)
    };

    for (&x, &y) in a_tail.iter().zip(b_tail) {
        let diff = x - y;
        total += diff * diff;
    }

    total
}

/// AVX2 kernel for [`dot_product`].
///
/// Accumulates eight lanes per step, reduces the accumulator horizontally, then adds
/// the last `< 8` products in scalar code.
///
/// Expects `a.len() == b.len()` (asserted by the public entry point). Loads are driven
/// by `chunks_exact(8)` on *both* slices, so a mismatched length can never cause an
/// out-of-bounds read from this safe function.
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
#[inline]
fn dot_product_avx2(a: &[f32], b: &[f32]) -> f32 {
    use core::arch::x86_64::*;

    debug_assert_eq!(a.len(), b.len(), "Vectors must have same length");

    let a_chunks = a.chunks_exact(8);
    let b_chunks = b.chunks_exact(8);
    let a_tail = a_chunks.remainder();
    let b_tail = b_chunks.remainder();

    // SAFETY: AVX2 is enabled for this build (enforced by the `cfg` on this function),
    // and each unaligned `_mm256_loadu_ps` reads exactly the eight elements of a
    // `chunks_exact(8)` chunk.
    let mut total = unsafe {
        let mut sum = _mm256_setzero_ps();
        for (ca, cb) in a_chunks.zip(b_chunks) {
            let va = _mm256_loadu_ps(ca.as_ptr());
            let vb = _mm256_loadu_ps(cb.as_ptr());
            sum = mul_add_avx2(va, vb, sum);
        }
        horizontal_sum_avx2(sum)
    };

    for (&x, &y) in a_tail.iter().zip(b_tail) {
        total += x * y;
    }

    total
}

/// AVX2 kernel for [`l2_norm`].
///
/// Squares and accumulates eight lanes per step, reduces horizontally, adds the last
/// `< 8` squares in scalar code, and returns the square root of the total.
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
#[inline]
fn l2_norm_avx2(a: &[f32]) -> f32 {
    use core::arch::x86_64::*;

    let chunks = a.chunks_exact(8);
    let tail = chunks.remainder();

    // SAFETY: AVX2 is enabled for this build (enforced by the `cfg` on this function),
    // and each unaligned `_mm256_loadu_ps` reads exactly the eight elements of a
    // `chunks_exact(8)` chunk.
    let mut total = unsafe {
        let mut sum = _mm256_setzero_ps();
        for chunk in chunks {
            let va = _mm256_loadu_ps(chunk.as_ptr());
            sum = mul_add_avx2(va, va, sum);
        }
        horizontal_sum_avx2(sum)
    };

    for &x in tail {
        total += x * x;
    }

    total.sqrt()
}

/// Computes `a * b + acc` on eight `f32` lanes.
///
/// The AVX2 kernels are gated only on the `avx2` target feature, but AVX2 does not
/// imply FMA. Calling `_mm256_fmadd_ps` in a build without `fma` enabled emits an FMA
/// instruction the target is not guaranteed to support (and that may not be inlined).
/// The fused form is therefore used only when `target_feature = "fma"` is set; otherwise
/// a separate multiply and add is used.
///
/// # Safety
///
/// The CPU must support AVX2 (and FMA when the fused path is compiled in). This is
/// guaranteed by the `cfg` gates, which require those target features at build time.
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
#[inline]
unsafe fn mul_add_avx2(
    a: core::arch::x86_64::__m256,
    b: core::arch::x86_64::__m256,
    acc: core::arch::x86_64::__m256,
) -> core::arch::x86_64::__m256 {
    use core::arch::x86_64::*;

    #[cfg(target_feature = "fma")]
    {
        _mm256_fmadd_ps(a, b, acc)
    }
    #[cfg(not(target_feature = "fma"))]
    {
        _mm256_add_ps(_mm256_mul_ps(a, b), acc)
    }
}

/// Sums the eight `f32` lanes of `v`.
///
/// Folds the upper 128-bit half onto the lower half, then reduces the remaining four
/// lanes `s0..s3` with shuffles to `(s0 + s1) + (s2 + s3)`.
///
/// # Safety
///
/// The CPU must support AVX and SSE3; guaranteed when this `cfg`-gated function is
/// compiled with `target_feature = "avx2"`.
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
#[inline]
unsafe fn horizontal_sum_avx2(v: core::arch::x86_64::__m256) -> f32 {
    use core::arch::x86_64::*;

    // Add high 128 bits to low 128 bits: [s0, s1, s2, s3]
    let high = _mm256_extractf128_ps(v, 1);
    let low = _mm256_castps256_ps128(v);
    let sum128 = _mm_add_ps(high, low);

    // [s1, s1, s3, s3] added lane-wise gives lane 0 = s0 + s1, lane 2 = s2 + s3.
    let shuf = _mm_movehdup_ps(sum128);
    let sums = _mm_add_ps(sum128, shuf);
    // Move lane 2 down to lane 0 and add: (s0 + s1) + (s2 + s3).
    let shuf2 = _mm_movehl_ps(sums, sums);
    let sums2 = _mm_add_ss(sums, shuf2);

    _mm_cvtss_f32(sums2)
}
501
502// ============================================================================
503// Scalar Fallback Implementations
504// ============================================================================
505
/// Scalar cosine similarity (portable fallback).
///
/// Accumulates the dot product and both squared norms in a single sequential pass,
/// then returns `dot / (|a| * |b|)`. Returns `0.0` when either vector has zero
/// magnitude.
///
/// # Panics
///
/// Panics if vectors have different lengths.
#[inline]
pub fn cosine_similarity_scalar(a: &[f32], b: &[f32]) -> f32 {
    // Without this check a longer `b` was silently truncated and a shorter one
    // panicked with an index error; match the SIMD entry points instead.
    assert_eq!(a.len(), b.len(), "Vectors must have same length");

    let mut dot = 0.0f32;
    let mut norm_a = 0.0f32;
    let mut norm_b = 0.0f32;

    // Zipping (after the length check) visits elements in the same order as an index
    // loop, without per-element bounds checks.
    for (&ai, &bi) in a.iter().zip(b) {
        dot += ai * bi;
        norm_a += ai * ai;
        norm_b += bi * bi;
    }

    let norm_a = norm_a.sqrt();
    let norm_b = norm_b.sqrt();

    if norm_a == 0.0 || norm_b == 0.0 {
        0.0
    } else {
        dot / (norm_a * norm_b)
    }
}
530
/// Scalar squared Euclidean distance (portable fallback): `Σ (a[i] - b[i])²`.
///
/// # Panics
///
/// Panics if vectors have different lengths.
#[inline]
pub fn euclidean_distance_squared_scalar(a: &[f32], b: &[f32]) -> f32 {
    // Previously a longer `b` was silently truncated; reject mismatches consistently.
    assert_eq!(a.len(), b.len(), "Vectors must have same length");

    // Folding from +0.0 (rather than `.sum()`) keeps the empty-input result at +0.0.
    a.iter().zip(b).fold(0.0f32, |sum, (&x, &y)| {
        let diff = x - y;
        sum + diff * diff
    })
}
541
/// Scalar dot product (portable fallback): `Σ a[i] * b[i]`, accumulated in order.
///
/// # Panics
///
/// Panics if vectors have different lengths.
#[inline]
pub fn dot_product_scalar(a: &[f32], b: &[f32]) -> f32 {
    // Previously a longer `b` was silently truncated; reject mismatches consistently.
    assert_eq!(a.len(), b.len(), "Vectors must have same length");

    // Folding from +0.0 (rather than `.sum()`) keeps the empty-input result at +0.0.
    a.iter().zip(b).fold(0.0f32, |sum, (&x, &y)| sum + x * y)
}
551
/// Scalar L2 norm (portable fallback): square root of the sum of squared elements.
///
/// Squares are accumulated sequentially starting from `+0.0`, so an empty slice
/// yields `0.0`.
#[inline]
pub fn l2_norm_scalar(a: &[f32]) -> f32 {
    let sum_of_squares = a.iter().fold(0.0f32, |acc, &x| acc + x * x);
    sum_of_squares.sqrt()
}
561
562// ============================================================================
563// Batch Distance Computation
564// ============================================================================
565
/// Result of a batch distance computation.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DistanceResult {
    /// Caller-supplied identifier of the vector, taken from the `(index, vector)` pair
    /// passed to the batch function (not necessarily its position in the batch).
    pub index: usize,
    /// Distance from the query to this vector; smaller means closer. For
    /// `batch_cosine_distances` this is `1 - cosine_similarity`.
    pub distance: f32,
}
574
/// Computes cosine distances (`1 - cosine_similarity`) from a query vector to every
/// vector in a batch and returns the `k` closest.
///
/// Returns results sorted by distance (closest first). NaN distances, which can
/// arise from non-finite inputs, are ordered after every other distance, so they are
/// never preferred over a real match.
///
/// # Arguments
///
/// * `query` - The query vector
/// * `batch` - Iterator over (index, vector) pairs; `index` is copied into the result
/// * `k` - Maximum number of results to return
///
/// # Panics
///
/// Panics if any vector in batch has different length than query.
#[cfg(feature = "alloc")]
pub fn batch_cosine_distances<'a, I>(
    query: &[f32],
    batch: I,
    k: usize,
) -> alloc::vec::Vec<DistanceResult>
where
    I: Iterator<Item = (usize, &'a [f32])>,
{
    use alloc::vec::Vec;
    use core::cmp::Ordering;

    /// Total order on distances: non-NaN values ascending, then NaN.
    ///
    /// `partial_cmp(..).unwrap_or(Equal)` is not transitive once NaN is present, and
    /// the standard select/sort routines may panic or misorder on such comparators.
    fn by_distance(x: &DistanceResult, y: &DistanceResult) -> Ordering {
        x.distance
            .is_nan()
            .cmp(&y.distance.is_nan())
            .then_with(|| x.distance.total_cmp(&y.distance))
    }

    let mut results: Vec<DistanceResult> = batch
        .map(|(index, vector)| DistanceResult {
            index,
            distance: 1.0 - cosine_similarity(query, vector), // Convert to distance
        })
        .collect();

    // Partial selection: move the k smallest to the front in O(n), drop the rest.
    if results.len() > k {
        results.select_nth_unstable_by(k, by_distance);
        results.truncate(k);
    }

    results.sort_by(by_distance);

    results
}
624
#[cfg(test)]
mod tests {
    use super::*;
    extern crate alloc;
    use alloc::vec::Vec;

    /// Builds `len` evenly spaced values `offset, offset + step, ...`.
    fn ramp(len: usize, step: f32, offset: f32) -> Vec<f32> {
        (0..len).map(|i| (i as f32) * step + offset).collect()
    }

    /// Relative comparison for results accumulated in different orders
    /// (SIMD lanes vs. a sequential loop); absolute near zero.
    fn close(x: f32, y: f32, rel: f32) -> bool {
        (x - y).abs() <= rel * x.abs().max(y.abs()).max(1.0)
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let a = [1.0f32, 2.0, 3.0, 4.0];
        let b = [1.0f32, 2.0, 3.0, 4.0];
        let sim = cosine_similarity(&a, &b);
        assert!((sim - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_opposite() {
        let a = [1.0f32, 0.0, 0.0, 0.0];
        let b = [-1.0f32, 0.0, 0.0, 0.0];
        let sim = cosine_similarity(&a, &b);
        assert!((sim - (-1.0)).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = [1.0f32, 0.0, 0.0, 0.0];
        let b = [0.0f32, 1.0, 0.0, 0.0];
        let sim = cosine_similarity(&a, &b);
        assert!(sim.abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_zero_vector() {
        let a = [0.0f32; 5];
        let b = [1.0f32, 2.0, 3.0, 4.0, 5.0];
        assert_eq!(cosine_similarity(&a, &b), 0.0);
        assert_eq!(cosine_similarity(&b, &a), 0.0);
    }

    #[test]
    fn test_euclidean_distance_zero() {
        let a = [1.0f32, 2.0, 3.0, 4.0];
        let b = [1.0f32, 2.0, 3.0, 4.0];
        let dist = euclidean_distance_squared(&a, &b);
        assert!(dist.abs() < 1e-6);
    }

    #[test]
    fn test_euclidean_distance_known() {
        let a = [0.0f32, 0.0, 0.0, 0.0];
        let b = [3.0f32, 4.0, 0.0, 0.0];
        let dist = euclidean_distance_squared(&a, &b);
        assert!((dist - 25.0).abs() < 1e-6); // 3^2 + 4^2 = 25
    }

    #[test]
    fn test_dot_product() {
        let a = [1.0f32, 2.0, 3.0, 4.0];
        let b = [2.0f32, 3.0, 4.0, 5.0];
        let dot = dot_product(&a, &b);
        assert!((dot - 40.0).abs() < 1e-6); // 2+6+12+20 = 40
    }

    #[test]
    fn test_l2_norm() {
        let a = [3.0f32, 4.0, 0.0, 0.0];
        let norm = l2_norm(&a);
        assert!((norm - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_empty_vectors() {
        assert_eq!(cosine_similarity(&[], &[]), 0.0);
        assert_eq!(euclidean_distance_squared(&[], &[]), 0.0);
        assert_eq!(dot_product(&[], &[]), 0.0);
        assert_eq!(l2_norm(&[]), 0.0);
    }

    #[test]
    #[should_panic(expected = "Vectors must have same length")]
    fn test_length_mismatch_panics() {
        let _ = dot_product(&[1.0, 2.0, 3.0], &[1.0, 2.0]);
    }

    #[test]
    fn test_large_vector() {
        // Test with 768 dimensions (common embedding size)
        let a: Vec<f32> = (0..768).map(|i| (i as f32) * 0.01).collect();
        let b: Vec<f32> = (0..768).map(|i| (i as f32) * 0.02).collect();

        let sim = cosine_similarity(&a, &b);
        assert!(sim > 0.99); // Same direction, just different magnitude

        let dot = dot_product(&a, &b);
        assert!(dot > 0.0);
    }

    #[test]
    fn test_scalar_matches_simd() {
        let a: Vec<f32> = (0..128).map(|i| (i as f32) * 0.1).collect();
        let b: Vec<f32> = (0..128).map(|i| ((i + 1) as f32) * 0.1).collect();

        let scalar_sim = cosine_similarity_scalar(&a, &b);
        let simd_sim = cosine_similarity(&a, &b);

        assert!((scalar_sim - simd_sim).abs() < 1e-5);

        let scalar_dist = euclidean_distance_squared_scalar(&a, &b);
        let simd_dist = euclidean_distance_squared(&a, &b);

        assert!((scalar_dist - simd_dist).abs() < 1e-4);
    }

    #[test]
    fn test_scalar_matches_simd_for_every_tail_length() {
        // 0..=19 covers every remainder for both the 4-lane (NEON) and 8-lane (AVX2)
        // kernels, including inputs shorter than one full vector. The 128-element test
        // above never reaches the scalar tail loops.
        for len in 0..=19 {
            let a = ramp(len, 0.1, -0.7);
            let b = ramp(len, -0.05, 0.3);

            assert!(
                close(cosine_similarity(&a, &b), cosine_similarity_scalar(&a, &b), 1e-5),
                "cosine, len = {len}"
            );
            assert!(
                close(
                    euclidean_distance_squared(&a, &b),
                    euclidean_distance_squared_scalar(&a, &b),
                    1e-5
                ),
                "euclidean, len = {len}"
            );
            assert!(
                close(dot_product(&a, &b), dot_product_scalar(&a, &b), 1e-5),
                "dot, len = {len}"
            );
            assert!(close(l2_norm(&a), l2_norm_scalar(&a), 1e-5), "l2_norm, len = {len}");
        }
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_batch_cosine_distances_returns_k_closest_in_order() {
        let query = [1.0f32, 0.0];
        let same = [2.0f32, 0.0]; // distance 0
        let orthogonal = [0.0f32, 3.0]; // distance 1
        let opposite = [-1.0f32, 0.0]; // distance 2
        let batch = [(10usize, &opposite[..]), (20, &same[..]), (30, &orthogonal[..])];

        let top2 = batch_cosine_distances(&query, batch.iter().copied(), 2);
        let indices: Vec<usize> = top2.iter().map(|r| r.index).collect();
        assert_eq!(indices, [20, 30]);

        let all = batch_cosine_distances(&query, batch.iter().copied(), 10);
        let indices: Vec<usize> = all.iter().map(|r| r.index).collect();
        assert_eq!(indices, [20, 30, 10]);
    }
}