vectx_core/simd.rs

// SIMD optimizations for vector operations
// Uses platform-specific SIMD intrinsics for maximum performance

#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;

// Minimum dimension sizes for SIMD optimization
#[cfg(target_arch = "x86_64")]
const MIN_DIM_SIZE_AVX: usize = 32;

#[cfg(any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64"))]
const MIN_DIM_SIZE_SIMD: usize = 16;
/// SIMD-optimized dot product.
/// For cosine similarity, both vectors should already be L2-normalized.
/// Falls back to optimized scalar code with better pipelining (similar to
/// Redis's scalar fallback) when no SIMD path is available.
#[inline]
pub fn dot_product_simd(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    // Try platform-specific SIMD if available
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2")
            && is_x86_feature_detected!("fma")
            && a.len() >= MIN_DIM_SIZE_AVX
        {
            return unsafe { dot_product_avx2(a, b) };
        }
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if is_x86_feature_detected!("sse") && a.len() >= MIN_DIM_SIZE_SIMD {
            return unsafe { dot_product_sse(a, b) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") && a.len() >= MIN_DIM_SIZE_SIMD {
            return unsafe { dot_product_neon(a, b) };
        }
    }

    // Optimized scalar fallback (like Redis's scalar implementation)
    // Uses two accumulators for better pipelining
    dot_product_scalar(a, b)
}
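
// Illustrative usage sketch: with unit-length inputs the dot product is the
// cosine similarity, so callers can use the dispatcher directly. The helper
// name below is hypothetical and not part of the module's public API.
#[allow(dead_code)]
fn cosine_similarity_normalized(a: &[f32], b: &[f32]) -> f32 {
    // Assumes `a` and `b` are already L2-normalized; otherwise divide the
    // result by `norm_simd(a) * norm_simd(b)`.
    dot_product_simd(a, b)
}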

/// AVX2-optimized dot product (16 floats at a time)
/// Inspired by Redis's vectors_distance_float_avx2
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
#[inline]
unsafe fn dot_product_avx2(a: &[f32], b: &[f32]) -> f32 {
    let dim = a.len();
    let mut i = 0;

    let mut sum1 = _mm256_setzero_ps();
    let mut sum2 = _mm256_setzero_ps();

    // Process 16 floats at a time with two AVX2 registers
    while i + 15 < dim {
        let vx1 = _mm256_loadu_ps(a.as_ptr().add(i));
        let vy1 = _mm256_loadu_ps(b.as_ptr().add(i));
        let vx2 = _mm256_loadu_ps(a.as_ptr().add(i + 8));
        let vy2 = _mm256_loadu_ps(b.as_ptr().add(i + 8));

        sum1 = _mm256_fmadd_ps(vx1, vy1, sum1);
        sum2 = _mm256_fmadd_ps(vx2, vy2, sum2);

        i += 16;
    }

    // Combine the two sums
    let combined = _mm256_add_ps(sum1, sum2);

    // Horizontal sum of the 8 elements
    let sum_high = _mm256_extractf128_ps(combined, 1);
    let sum_low = _mm256_castps256_ps128(combined);
    let mut sum_128 = _mm_add_ps(sum_high, sum_low);

    sum_128 = _mm_hadd_ps(sum_128, sum_128);
    sum_128 = _mm_hadd_ps(sum_128, sum_128);

    let mut dot = _mm_cvtss_f32(sum_128);

    // Handle remaining elements
    while i < dim {
        dot += a[i] * b[i];
        i += 1;
    }

    dot
}

/// SSE-optimized dot product (pattern similar to qdrant's SSE implementation)
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse")]
#[inline]
unsafe fn dot_product_sse(a: &[f32], b: &[f32]) -> f32 {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let dim = a.len();
    let mut i = 0;
    let mut sum = _mm_setzero_ps();

    // Process 4 floats at a time
    while i + 3 < dim {
        let va = _mm_loadu_ps(a.as_ptr().add(i));
        let vb = _mm_loadu_ps(b.as_ptr().add(i));
        sum = _mm_add_ps(sum, _mm_mul_ps(va, vb));
        i += 4;
    }

    // Horizontal sum
    let shuf = _mm_shuffle_ps(sum, sum, 0b10_11_00_01);
    sum = _mm_add_ps(sum, shuf);
    let shuf = _mm_movehl_ps(sum, sum);
    sum = _mm_add_ss(sum, shuf);

    let mut dot = _mm_cvtss_f32(sum);

    // Handle remaining elements
    while i < dim {
        dot += a[i] * b[i];
        i += 1;
    }

    dot
}

/// NEON-optimized dot product for ARM/Apple Silicon
/// Uses 8-wide processing with two NEON registers for better throughput
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[inline]
unsafe fn dot_product_neon(a: &[f32], b: &[f32]) -> f32 {
    let dim = a.len();
    let mut i = 0;

    // Use two accumulators for better instruction pipelining
    let mut sum1 = vdupq_n_f32(0.0);
    let mut sum2 = vdupq_n_f32(0.0);

    // Process 8 floats at a time with two NEON registers
    while i + 7 < dim {
        let va1 = vld1q_f32(a.as_ptr().add(i));
        let vb1 = vld1q_f32(b.as_ptr().add(i));
        let va2 = vld1q_f32(a.as_ptr().add(i + 4));
        let vb2 = vld1q_f32(b.as_ptr().add(i + 4));

        sum1 = vfmaq_f32(sum1, va1, vb1);
        sum2 = vfmaq_f32(sum2, va2, vb2);

        i += 8;
    }

    // Process remaining 4 floats
    while i + 3 < dim {
        let va = vld1q_f32(a.as_ptr().add(i));
        let vb = vld1q_f32(b.as_ptr().add(i));
        sum1 = vfmaq_f32(sum1, va, vb);
        i += 4;
    }

    // Combine accumulators and horizontal sum
    let combined = vaddq_f32(sum1, sum2);
    let mut dot = vaddvq_f32(combined);

    // Handle remaining elements
    while i < dim {
        dot += a[i] * b[i];
        i += 1;
    }

    dot
}

/// Scalar fallback (two accumulators for better pipelining - Redis pattern)
#[inline]
fn dot_product_scalar(a: &[f32], b: &[f32]) -> f32 {
    let mut dot0 = 0.0f32;
    let mut dot1 = 0.0f32;

    // Process 8 elements at a time with two accumulators
    let chunks = a.chunks_exact(8);
    let remainder = chunks.remainder();
    let b_chunks = b.chunks_exact(8);

    for (a_chunk, b_chunk) in chunks.zip(b_chunks) {
        dot0 += a_chunk[0] * b_chunk[0] +
                a_chunk[1] * b_chunk[1] +
                a_chunk[2] * b_chunk[2] +
                a_chunk[3] * b_chunk[3];

        dot1 += a_chunk[4] * b_chunk[4] +
                a_chunk[5] * b_chunk[5] +
                a_chunk[6] * b_chunk[6] +
                a_chunk[7] * b_chunk[7];
    }

    // Handle remainder
    for i in (a.len() - remainder.len())..a.len() {
        dot0 += a[i] * b[i];
    }

    dot0 + dot1
}

/// SIMD-optimized L2 distance (Euclidean)
#[inline]
pub fn l2_distance_simd(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return f32::INFINITY;
    }

    // Try platform-specific SIMD if available
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2")
            && is_x86_feature_detected!("fma")
            && a.len() >= MIN_DIM_SIZE_AVX
        {
            return unsafe { l2_distance_avx2(a, b) };
        }
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if is_x86_feature_detected!("sse") && a.len() >= MIN_DIM_SIZE_SIMD {
            return unsafe { l2_distance_sse(a, b) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") && a.len() >= MIN_DIM_SIZE_SIMD {
            return unsafe { l2_distance_neon(a, b) };
        }
    }

    l2_distance_scalar(a, b)
}
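
// Illustrative usage sketch: brute-force nearest-neighbor lookup over a small
// candidate set using the L2 dispatcher. The helper name is hypothetical.
#[allow(dead_code)]
fn nearest_by_l2(query: &[f32], candidates: &[Vec<f32>]) -> Option<usize> {
    candidates
        .iter()
        .enumerate()
        // Mismatched lengths yield f32::INFINITY and therefore never win.
        .map(|(idx, c)| (idx, l2_distance_simd(query, c)))
        .min_by(|(_, d1), (_, d2)| d1.partial_cmp(d2).unwrap_or(std::cmp::Ordering::Equal))
        .map(|(idx, _)| idx)
}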

/// AVX2-optimized L2 distance
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
#[inline]
unsafe fn l2_distance_avx2(a: &[f32], b: &[f32]) -> f32 {
    let dim = a.len();
    let mut i = 0;

    let mut sum1 = _mm256_setzero_ps();
    let mut sum2 = _mm256_setzero_ps();

    // Process 16 floats at a time with two AVX2 registers
    while i + 15 < dim {
        let va1 = _mm256_loadu_ps(a.as_ptr().add(i));
        let vb1 = _mm256_loadu_ps(b.as_ptr().add(i));
        let va2 = _mm256_loadu_ps(a.as_ptr().add(i + 8));
        let vb2 = _mm256_loadu_ps(b.as_ptr().add(i + 8));

        let diff1 = _mm256_sub_ps(va1, vb1);
        let diff2 = _mm256_sub_ps(va2, vb2);

        sum1 = _mm256_fmadd_ps(diff1, diff1, sum1);
        sum2 = _mm256_fmadd_ps(diff2, diff2, sum2);

        i += 16;
    }

    // Combine the two sums
    let combined = _mm256_add_ps(sum1, sum2);

    // Horizontal sum of the 8 elements
    let sum_high = _mm256_extractf128_ps(combined, 1);
    let sum_low = _mm256_castps256_ps128(combined);
    let mut sum_128 = _mm_add_ps(sum_high, sum_low);

    sum_128 = _mm_hadd_ps(sum_128, sum_128);
    sum_128 = _mm_hadd_ps(sum_128, sum_128);

    let mut sum_sq = _mm_cvtss_f32(sum_128);

    // Handle remaining elements
    while i < dim {
        let diff = a[i] - b[i];
        sum_sq += diff * diff;
        i += 1;
    }

    sum_sq.sqrt()
}

/// SSE-optimized L2 distance
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse")]
#[inline]
unsafe fn l2_distance_sse(a: &[f32], b: &[f32]) -> f32 {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let dim = a.len();
    let mut i = 0;
    let mut sum = _mm_setzero_ps();

    // Process 4 floats at a time
    while i + 3 < dim {
        let va = _mm_loadu_ps(a.as_ptr().add(i));
        let vb = _mm_loadu_ps(b.as_ptr().add(i));
        let diff = _mm_sub_ps(va, vb);
        sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
        i += 4;
    }

    // Horizontal sum
    let shuf = _mm_shuffle_ps(sum, sum, 0b10_11_00_01);
    sum = _mm_add_ps(sum, shuf);
    let shuf = _mm_movehl_ps(sum, sum);
    sum = _mm_add_ss(sum, shuf);

    let mut sum_sq = _mm_cvtss_f32(sum);

    // Handle remaining elements
    while i < dim {
        let diff = a[i] - b[i];
        sum_sq += diff * diff;
        i += 1;
    }

    sum_sq.sqrt()
}

/// NEON-optimized L2 distance for ARM/Apple Silicon
/// Uses 8-wide processing with two NEON registers for better throughput
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[inline]
unsafe fn l2_distance_neon(a: &[f32], b: &[f32]) -> f32 {
    let dim = a.len();
    let mut i = 0;

    // Use two accumulators for better instruction pipelining
    let mut sum1 = vdupq_n_f32(0.0);
    let mut sum2 = vdupq_n_f32(0.0);

    // Process 8 floats at a time with two NEON registers
    while i + 7 < dim {
        let va1 = vld1q_f32(a.as_ptr().add(i));
        let vb1 = vld1q_f32(b.as_ptr().add(i));
        let va2 = vld1q_f32(a.as_ptr().add(i + 4));
        let vb2 = vld1q_f32(b.as_ptr().add(i + 4));

        let diff1 = vsubq_f32(va1, vb1);
        let diff2 = vsubq_f32(va2, vb2);

        sum1 = vfmaq_f32(sum1, diff1, diff1);
        sum2 = vfmaq_f32(sum2, diff2, diff2);

        i += 8;
    }

    // Process remaining 4 floats
    while i + 3 < dim {
        let va = vld1q_f32(a.as_ptr().add(i));
        let vb = vld1q_f32(b.as_ptr().add(i));
        let diff = vsubq_f32(va, vb);
        sum1 = vfmaq_f32(sum1, diff, diff);
        i += 4;
    }

    // Combine accumulators and horizontal sum
    let combined = vaddq_f32(sum1, sum2);
    let mut sum_sq = vaddvq_f32(combined);

    // Handle remaining elements
    while i < dim {
        let diff = a[i] - b[i];
        sum_sq += diff * diff;
        i += 1;
    }

    sum_sq.sqrt()
}

/// Scalar L2 distance (two accumulators for better pipelining)
#[inline]
fn l2_distance_scalar(a: &[f32], b: &[f32]) -> f32 {
    let mut sum0 = 0.0f32;
    let mut sum1 = 0.0f32;

    // Process 4 elements at a time with two accumulators
    let chunks = a.chunks_exact(4);
    let remainder = chunks.remainder();
    let b_chunks = b.chunks_exact(4);

    for (a_chunk, b_chunk) in chunks.zip(b_chunks) {
        let d0 = a_chunk[0] - b_chunk[0];
        let d1 = a_chunk[1] - b_chunk[1];
        let d2 = a_chunk[2] - b_chunk[2];
        let d3 = a_chunk[3] - b_chunk[3];

        sum0 += d0 * d0 + d1 * d1;
        sum1 += d2 * d2 + d3 * d3;
    }

    // Handle remainder
    for i in (a.len() - remainder.len())..a.len() {
        let diff = a[i] - b[i];
        sum0 += diff * diff;
    }

    (sum0 + sum1).sqrt()
}

/// SIMD-optimized vector norm (squared length)
#[inline]
pub fn norm_squared_simd(v: &[f32]) -> f32 {
    dot_product_simd(v, v)
}

/// SIMD-optimized vector norm (length)
#[inline]
pub fn norm_simd(v: &[f32]) -> f32 {
    norm_squared_simd(v).sqrt()
}
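
// Illustrative test sketch (hypothetical, not an existing suite): whichever
// code path the CPU selects at runtime, the dispatchers should agree with the
// scalar fallbacks to within floating-point tolerance.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn simd_paths_match_scalar() {
        // 67 elements: long enough to trigger the AVX/NEON paths and odd
        // enough to exercise the scalar tail handling.
        let a: Vec<f32> = (0..67).map(|i| (i as f32 * 0.37).sin()).collect();
        let b: Vec<f32> = (0..67).map(|i| (i as f32 * 0.71).cos()).collect();

        let dot = dot_product_simd(&a, &b);
        let dot_ref = dot_product_scalar(&a, &b);
        assert!((dot - dot_ref).abs() < 1e-3);

        let l2 = l2_distance_simd(&a, &b);
        let l2_ref = l2_distance_scalar(&a, &b);
        assert!((l2 - l2_ref).abs() < 1e-3);
    }
}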