ruvector_core/simd_intrinsics.rs

//! Custom SIMD intrinsics for performance-critical operations
//!
//! This module provides hand-optimized SIMD implementations using AVX2 for
//! distance calculations and other vectorized operations, with scalar
//! fallbacks for CPUs and targets where AVX2 is unavailable.

#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

/// SIMD-optimized Euclidean distance using AVX2.
/// Falls back to a scalar implementation if AVX2 is not available.
///
/// Both slices must have the same length; this is checked with a debug assertion.
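///
/// # Examples
///
/// A minimal usage sketch; the crate and module path (`ruvector_core::simd_intrinsics`)
/// is assumed here and may differ in your build.
///
/// ```
/// use ruvector_core::simd_intrinsics::euclidean_distance_avx2;
///
/// // 3-4-5 right triangle: the distance between these points is exactly 5.
/// let a = [0.0_f32, 3.0, 0.0, 0.0];
/// let b = [4.0_f32, 0.0, 0.0, 0.0];
/// assert!((euclidean_distance_avx2(&a, &b) - 5.0).abs() < 1e-6);
/// ```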
#[inline]
pub fn euclidean_distance_avx2(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "input slices must have equal length");

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            unsafe { euclidean_distance_avx2_impl(a, b) }
        } else {
            euclidean_distance_scalar(a, b)
        }
    }

    #[cfg(not(target_arch = "x86_64"))]
    {
        euclidean_distance_scalar(a, b)
    }
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn euclidean_distance_avx2_impl(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len();
    let mut sum = _mm256_setzero_ps();

    // Process 8 floats at a time
    let chunks = len / 8;
    for i in 0..chunks {
        let idx = i * 8;

        // Load 8 floats from each array
        let va = _mm256_loadu_ps(a.as_ptr().add(idx));
        let vb = _mm256_loadu_ps(b.as_ptr().add(idx));

        // Compute difference: (a - b)
        let diff = _mm256_sub_ps(va, vb);

        // Square the difference: (a - b)^2
        let sq = _mm256_mul_ps(diff, diff);

        // Accumulate
        sum = _mm256_add_ps(sum, sq);
    }

    // Horizontal sum of the 8 floats in the AVX register
    let sum_arr: [f32; 8] = std::mem::transmute(sum);
    let mut total = sum_arr.iter().sum::<f32>();

    // Handle remaining elements (if len not divisible by 8)
    for i in (chunks * 8)..len {
        let diff = a[i] - b[i];
        total += diff * diff;
    }

    total.sqrt()
}

/// SIMD-optimized dot product using AVX2.
/// Falls back to a scalar implementation if AVX2 is not available.
///
/// Both slices must have the same length; this is checked with a debug assertion.
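///
/// # Examples
///
/// A minimal usage sketch; the crate and module path (`ruvector_core::simd_intrinsics`)
/// is assumed here and may differ in your build.
///
/// ```
/// use ruvector_core::simd_intrinsics::dot_product_avx2;
///
/// let a = [1.0_f32, 2.0, 3.0];
/// let b = [4.0_f32, 5.0, 6.0];
/// // 1*4 + 2*5 + 3*6 = 32
/// assert!((dot_product_avx2(&a, &b) - 32.0).abs() < 1e-6);
/// ```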
#[inline]
pub fn dot_product_avx2(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "input slices must have equal length");

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            unsafe { dot_product_avx2_impl(a, b) }
        } else {
            dot_product_scalar(a, b)
        }
    }

    #[cfg(not(target_arch = "x86_64"))]
    {
        dot_product_scalar(a, b)
    }
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn dot_product_avx2_impl(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len();
    let mut sum = _mm256_setzero_ps();

    // Process 8 floats at a time
    let chunks = len / 8;
    for i in 0..chunks {
        let idx = i * 8;
        let va = _mm256_loadu_ps(a.as_ptr().add(idx));
        let vb = _mm256_loadu_ps(b.as_ptr().add(idx));
        let prod = _mm256_mul_ps(va, vb);
        sum = _mm256_add_ps(sum, prod);
    }

    // Horizontal sum of the 8 floats in the AVX register
    let sum_arr: [f32; 8] = std::mem::transmute(sum);
    let mut total = sum_arr.iter().sum::<f32>();

    // Handle remaining elements (if len not divisible by 8)
    for i in (chunks * 8)..len {
        total += a[i] * b[i];
    }

    total
}

/// SIMD-optimized cosine similarity using AVX2.
/// Falls back to a scalar implementation if AVX2 is not available.
///
/// Both slices must have the same length; this is checked with a debug assertion.
/// Returns `NaN` if either vector has zero magnitude.
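///
/// # Examples
///
/// A minimal usage sketch; the crate and module path (`ruvector_core::simd_intrinsics`)
/// is assumed here and may differ in your build.
///
/// ```
/// use ruvector_core::simd_intrinsics::cosine_similarity_avx2;
///
/// // Orthogonal unit vectors have a cosine similarity of 0.
/// let a = [1.0_f32, 0.0];
/// let b = [0.0_f32, 1.0];
/// assert!(cosine_similarity_avx2(&a, &b).abs() < 1e-6);
/// ```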
#[inline]
pub fn cosine_similarity_avx2(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "input slices must have equal length");

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            unsafe { cosine_similarity_avx2_impl(a, b) }
        } else {
            cosine_similarity_scalar(a, b)
        }
    }

    #[cfg(not(target_arch = "x86_64"))]
    {
        cosine_similarity_scalar(a, b)
    }
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn cosine_similarity_avx2_impl(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len();
    let mut dot = _mm256_setzero_ps();
    let mut norm_a = _mm256_setzero_ps();
    let mut norm_b = _mm256_setzero_ps();

    // Process 8 floats at a time
    let chunks = len / 8;
    for i in 0..chunks {
        let idx = i * 8;
        let va = _mm256_loadu_ps(a.as_ptr().add(idx));
        let vb = _mm256_loadu_ps(b.as_ptr().add(idx));

        // Dot product
        dot = _mm256_add_ps(dot, _mm256_mul_ps(va, vb));

        // Norms
        norm_a = _mm256_add_ps(norm_a, _mm256_mul_ps(va, va));
        norm_b = _mm256_add_ps(norm_b, _mm256_mul_ps(vb, vb));
    }

    // Horizontal sums of the three accumulator registers
    let dot_arr: [f32; 8] = std::mem::transmute(dot);
    let norm_a_arr: [f32; 8] = std::mem::transmute(norm_a);
    let norm_b_arr: [f32; 8] = std::mem::transmute(norm_b);

    let mut dot_sum = dot_arr.iter().sum::<f32>();
    let mut norm_a_sum = norm_a_arr.iter().sum::<f32>();
    let mut norm_b_sum = norm_b_arr.iter().sum::<f32>();

    // Handle remaining elements (if len not divisible by 8)
    for i in (chunks * 8)..len {
        dot_sum += a[i] * b[i];
        norm_a_sum += a[i] * a[i];
        norm_b_sum += b[i] * b[i];
    }

    dot_sum / (norm_a_sum.sqrt() * norm_b_sum.sqrt())
}

// Scalar fallback implementations

fn euclidean_distance_scalar(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b.iter())
        .map(|(x, y)| {
            let diff = x - y;
            diff * diff
        })
        .sum::<f32>()
        .sqrt()
}

fn dot_product_scalar(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}

fn cosine_similarity_scalar(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    dot / (norm_a * norm_b)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_euclidean_distance_avx2() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let b = vec![2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];

        let result = euclidean_distance_avx2(&a, &b);
        let expected = euclidean_distance_scalar(&a, &b);

        assert!(
            (result - expected).abs() < 0.001,
            "AVX2 result {} differs from scalar result {}",
            result,
            expected
        );
    }

    #[test]
    fn test_dot_product_avx2() {
        let a = vec![1.0; 16];
        let b = vec![2.0; 16];

        let result = dot_product_avx2(&a, &b);
        assert!((result - 32.0).abs() < 0.001);
    }

    #[test]
    fn test_cosine_similarity_avx2() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];

        let result = cosine_similarity_avx2(&a, &b);
        assert!((result - 1.0).abs() < 0.001);
    }
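
    // Additional check (not in the original suite): a length that is not a
    // multiple of 8 exercises both the 8-wide SIMD loop and the scalar
    // remainder loop, and should still match the scalar implementations.
    #[test]
    fn test_remainder_path_matches_scalar() {
        let a: Vec<f32> = (0..11).map(|i| i as f32 * 0.5).collect();
        let b: Vec<f32> = (0..11).map(|i| (10 - i) as f32 * 0.25).collect();

        assert!((euclidean_distance_avx2(&a, &b) - euclidean_distance_scalar(&a, &b)).abs() < 1e-4);
        assert!((dot_product_avx2(&a, &b) - dot_product_scalar(&a, &b)).abs() < 1e-4);
        assert!((cosine_similarity_avx2(&a, &b) - cosine_similarity_scalar(&a, &b)).abs() < 1e-4);
    }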
}