Skip to main content

ruvector_core/
simd_intrinsics.rs

1//! Custom SIMD intrinsics for performance-critical operations
2//!
3//! This module provides hand-optimized SIMD implementations:
4//! - AVX2/AVX-512 for x86_64 processors
5//! - NEON for ARM64/Apple Silicon processors (M1/M2/M3/M4)
6//!
7//! Distance calculations and other vectorized operations are automatically
8//! dispatched to the optimal implementation based on the target architecture.
9//!
10//! ## Features
11//!
12//! - **AVX-512 Support**: 512-bit operations processing 16 floats per iteration
13//! - **INT8 Quantized Operations**: SIMD-accelerated quantized vector operations
14//! - **Batch Operations**: Cache-optimized batch distance calculations
15//! - **NEON Optimizations**: Prefetch hints and loop unrolling for ARM64
16//!
17//! ## Performance Optimizations (v2)
18//!
19//! - **Loop Unrolling**: 4x unrolled loops for better instruction-level parallelism
20//! - **Prefetch Hints**: Software prefetching for large vectors (>256 elements)
21//! - **FMA Instructions**: Fused multiply-add for improved throughput and accuracy
22//! - **Efficient Horizontal Sum**: Optimized reduction operations
23
24#[cfg(target_arch = "x86_64")]
25use std::arch::x86_64::*;
26
27#[cfg(target_arch = "aarch64")]
28use std::arch::aarch64::*;
29
/// Software-prefetch look-ahead distance.
///
/// NOTE(review): none of the kernels visible alongside this constant read it,
/// so its unit is unconfirmed. It was previously described as "cache lines",
/// but 64 cache lines is 4 KiB ahead. With 64-byte lines (16 `f32`s each),
/// a value of 64 more plausibly means 64 elements (4 lines). Confirm the unit
/// before wiring it into a kernel.
// Allowed because no kernel currently references this constant.
#[allow(dead_code)]
const PREFETCH_DISTANCE: usize = 64;
33
34/// SIMD-optimized euclidean distance
35/// Uses AVX-512 > AVX2 on x86_64, NEON on ARM64/Apple Silicon, falls back to scalar otherwise
36///
37/// # Optimizations for M4 Pro (ARM64)
38/// - Uses 4x loop unrolling for vectors >= 64 elements
39/// - FMA instructions for improved throughput
40/// - Optimized horizontal reduction via `vaddvq_f32`
41#[inline(always)]
42pub fn euclidean_distance_simd(a: &[f32], b: &[f32]) -> f32 {
43    #[cfg(target_arch = "x86_64")]
44    {
45        if is_x86_feature_detected!("avx512f") {
46            unsafe { euclidean_distance_avx512_impl(a, b) }
47        } else if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
48            unsafe { euclidean_distance_avx2_fma_impl(a, b) }
49        } else if is_x86_feature_detected!("avx2") {
50            unsafe { euclidean_distance_avx2_impl(a, b) }
51        } else {
52            euclidean_distance_scalar(a, b)
53        }
54    }
55
56    #[cfg(target_arch = "aarch64")]
57    {
58        // Use unrolled version for vectors >= 64 elements for better ILP
59        if a.len() >= 64 {
60            unsafe { euclidean_distance_neon_unrolled_impl(a, b) }
61        } else {
62            unsafe { euclidean_distance_neon_impl(a, b) }
63        }
64    }
65
66    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
67    {
68        euclidean_distance_scalar(a, b)
69    }
70}
71
72/// Legacy alias for backward compatibility
73#[inline(always)]
74pub fn euclidean_distance_avx2(a: &[f32], b: &[f32]) -> f32 {
75    euclidean_distance_simd(a, b)
76}
77
/// Plain AVX2 Euclidean distance.
///
/// Accumulates squared differences 8 lanes at a time using a separate
/// multiply and add (no FMA). It then sums the lanes, adds the 0-7 leftover
/// elements in scalar code, and returns the square root.
///
/// # Safety
/// The CPU must support AVX2.
///
/// # Panics
/// Panics if `a` and `b` differ in length.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn euclidean_distance_avx2_impl(a: &[f32], b: &[f32]) -> f32 {
    // SECURITY: the raw pointer loads below read `b` using `a`'s length.
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    const LANES: usize = 8;
    let n = a.len();
    let vector_end = n - n % LANES;
    let (pa, pb) = (a.as_ptr(), b.as_ptr());

    let mut acc = _mm256_setzero_ps();
    let mut offset = 0;
    while offset < vector_end {
        let va = _mm256_loadu_ps(pa.add(offset));
        let vb = _mm256_loadu_ps(pb.add(offset));
        let d = _mm256_sub_ps(va, vb);
        acc = _mm256_add_ps(acc, _mm256_mul_ps(d, d));
        offset += LANES;
    }

    // Horizontal reduction of the 8 lanes.
    let lanes: [f32; LANES] = std::mem::transmute(acc);
    let mut sq_sum: f32 = lanes.iter().sum();

    for (&x, &y) in a[vector_end..].iter().zip(&b[vector_end..]) {
        let d = x - y;
        sq_sum += d * d;
    }

    sq_sum.sqrt()
}
118
/// AVX2 + FMA Euclidean distance with 4x loop unrolling.
///
/// The main loop covers 32 floats per iteration with four independent
/// accumulators, so consecutive FMAs do not depend on each other. Leftover
/// whole 8-float chunks are folded into the merged accumulator. The final
/// 0-7 elements are handled in scalar code before the square root.
///
/// # Safety
/// The CPU must support both AVX2 and FMA.
///
/// # Panics
/// Panics if `a` and `b` differ in length.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn euclidean_distance_avx2_fma_impl(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let n = a.len();
    let (pa, pb) = (a.as_ptr(), b.as_ptr());
    let unrolled_end = n - n % 32;
    let vector_end = n - n % 8;

    let mut acc0 = _mm256_setzero_ps();
    let mut acc1 = _mm256_setzero_ps();
    let mut acc2 = _mm256_setzero_ps();
    let mut acc3 = _mm256_setzero_ps();

    let mut pos = 0;
    while pos < unrolled_end {
        let (qa, qb) = (pa.add(pos), pb.add(pos));
        let d0 = _mm256_sub_ps(_mm256_loadu_ps(qa), _mm256_loadu_ps(qb));
        let d1 = _mm256_sub_ps(_mm256_loadu_ps(qa.add(8)), _mm256_loadu_ps(qb.add(8)));
        let d2 = _mm256_sub_ps(_mm256_loadu_ps(qa.add(16)), _mm256_loadu_ps(qb.add(16)));
        let d3 = _mm256_sub_ps(_mm256_loadu_ps(qa.add(24)), _mm256_loadu_ps(qb.add(24)));
        acc0 = _mm256_fmadd_ps(d0, d0, acc0);
        acc1 = _mm256_fmadd_ps(d1, d1, acc1);
        acc2 = _mm256_fmadd_ps(d2, d2, acc2);
        acc3 = _mm256_fmadd_ps(d3, d3, acc3);
        pos += 32;
    }

    // Pairwise merge: (acc0 + acc1) + (acc2 + acc3).
    let mut acc = _mm256_add_ps(_mm256_add_ps(acc0, acc1), _mm256_add_ps(acc2, acc3));

    // Remaining whole 8-float chunks feed the merged accumulator directly.
    while pos < vector_end {
        let d = _mm256_sub_ps(_mm256_loadu_ps(pa.add(pos)), _mm256_loadu_ps(pb.add(pos)));
        acc = _mm256_fmadd_ps(d, d, acc);
        pos += 8;
    }

    let lanes: [f32; 8] = std::mem::transmute(acc);
    let mut sq_sum: f32 = lanes.iter().sum();

    for (&x, &y) in a[vector_end..].iter().zip(&b[vector_end..]) {
        let d = x - y;
        sq_sum += d * d;
    }

    sq_sum.sqrt()
}
189
190// ============================================================================
191// AVX-512 implementations for x86_64 (Intel Ice Lake, Sapphire Rapids, AMD Zen 4+)
192// ============================================================================
193
194/// AVX-512 euclidean distance - processes 16 floats per iteration
195#[cfg(target_arch = "x86_64")]
196#[target_feature(enable = "avx512f")]
197unsafe fn euclidean_distance_avx512_impl(a: &[f32], b: &[f32]) -> f32 {
198    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");
199
200    let len = a.len();
201    let mut sum = _mm512_setzero_ps();
202
203    // Process 16 floats at a time
204    let chunks = len / 16;
205    for i in 0..chunks {
206        let idx = i * 16;
207        let va = _mm512_loadu_ps(a.as_ptr().add(idx));
208        let vb = _mm512_loadu_ps(b.as_ptr().add(idx));
209        let diff = _mm512_sub_ps(va, vb);
210        sum = _mm512_fmadd_ps(diff, diff, sum);
211    }
212
213    // Horizontal sum using AVX-512 reduction
214    let mut total = _mm512_reduce_add_ps(sum);
215
216    // Handle remaining elements (0-15 elements)
217    for i in (chunks * 16)..len {
218        let diff = a[i] - b[i];
219        total += diff * diff;
220    }
221
222    total.sqrt()
223}
224
225/// AVX-512 dot product - processes 16 floats per iteration
226#[cfg(target_arch = "x86_64")]
227#[target_feature(enable = "avx512f")]
228unsafe fn dot_product_avx512_impl(a: &[f32], b: &[f32]) -> f32 {
229    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");
230
231    let len = a.len();
232    let mut sum = _mm512_setzero_ps();
233
234    let chunks = len / 16;
235    for i in 0..chunks {
236        let idx = i * 16;
237        let va = _mm512_loadu_ps(a.as_ptr().add(idx));
238        let vb = _mm512_loadu_ps(b.as_ptr().add(idx));
239        sum = _mm512_fmadd_ps(va, vb, sum);
240    }
241
242    let mut total = _mm512_reduce_add_ps(sum);
243
244    for i in (chunks * 16)..len {
245        total += a[i] * b[i];
246    }
247
248    total
249}
250
251/// AVX-512 cosine similarity - processes 16 floats per iteration
252#[cfg(target_arch = "x86_64")]
253#[target_feature(enable = "avx512f")]
254unsafe fn cosine_similarity_avx512_impl(a: &[f32], b: &[f32]) -> f32 {
255    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");
256
257    let len = a.len();
258    let mut dot = _mm512_setzero_ps();
259    let mut norm_a = _mm512_setzero_ps();
260    let mut norm_b = _mm512_setzero_ps();
261
262    let chunks = len / 16;
263    for i in 0..chunks {
264        let idx = i * 16;
265        let va = _mm512_loadu_ps(a.as_ptr().add(idx));
266        let vb = _mm512_loadu_ps(b.as_ptr().add(idx));
267
268        dot = _mm512_fmadd_ps(va, vb, dot);
269        norm_a = _mm512_fmadd_ps(va, va, norm_a);
270        norm_b = _mm512_fmadd_ps(vb, vb, norm_b);
271    }
272
273    let mut dot_sum = _mm512_reduce_add_ps(dot);
274    let mut norm_a_sum = _mm512_reduce_add_ps(norm_a);
275    let mut norm_b_sum = _mm512_reduce_add_ps(norm_b);
276
277    for i in (chunks * 16)..len {
278        dot_sum += a[i] * b[i];
279        norm_a_sum += a[i] * a[i];
280        norm_b_sum += b[i] * b[i];
281    }
282
283    dot_sum / (norm_a_sum.sqrt() * norm_b_sum.sqrt())
284}
285
286/// AVX-512 Manhattan distance - processes 16 floats per iteration
287#[cfg(target_arch = "x86_64")]
288#[target_feature(enable = "avx512f")]
289unsafe fn manhattan_distance_avx512_impl(a: &[f32], b: &[f32]) -> f32 {
290    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");
291
292    let len = a.len();
293    let mut sum = _mm512_setzero_ps();
294
295    let chunks = len / 16;
296    for i in 0..chunks {
297        let idx = i * 16;
298        let va = _mm512_loadu_ps(a.as_ptr().add(idx));
299        let vb = _mm512_loadu_ps(b.as_ptr().add(idx));
300        let diff = _mm512_sub_ps(va, vb);
301        let abs_diff = _mm512_abs_ps(diff);
302        sum = _mm512_add_ps(sum, abs_diff);
303    }
304
305    let mut total = _mm512_reduce_add_ps(sum);
306
307    for i in (chunks * 16)..len {
308        total += (a[i] - b[i]).abs();
309    }
310
311    total
312}
313
314// ============================================================================
315// NEON implementations for ARM64/Apple Silicon (M1/M2/M3/M4)
316// ============================================================================
317
/// NEON-optimized euclidean distance for ARM64 (non-unrolled version).
/// Processes 4 floats at a time using 128-bit NEON registers.
///
/// # Panics
/// Panics if `a.len() != b.len()`. The check is unconditional rather than a
/// `debug_assert!`: the vector loads and `get_unchecked` reads below index
/// `b` with `a`'s length, and the safe public dispatchers call this without
/// validating lengths. A debug-only check would allow out-of-bounds reads in
/// release builds.
///
/// # Safety
/// Requires NEON. The function has no `#[target_feature]` attribute, so it
/// relies on NEON being enabled for the build target.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn euclidean_distance_neon_impl(a: &[f32], b: &[f32]) -> f32 {
    // SECURITY: must hold in release builds too (see "Panics").
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let mut sum = vdupq_n_f32(0.0);

    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    // Process 4 floats at a time with NEON
    let chunks = len / 4;
    let mut idx = 0usize;

    for _ in 0..chunks {
        // SAFETY: idx + 4 <= chunks * 4 <= len == b.len().
        let va = vld1q_f32(a_ptr.add(idx));
        let vb = vld1q_f32(b_ptr.add(idx));

        // Compute difference: (a - b)
        let diff = vsubq_f32(va, vb);

        // Square and accumulate: sum += (a - b)^2
        sum = vfmaq_f32(sum, diff, diff);

        idx += 4;
    }

    // Horizontal sum of the 4 floats
    let mut total = vaddvq_f32(sum);

    // Handle remaining elements
    for i in (chunks * 4)..len {
        // SAFETY: i < len == b.len() (asserted above).
        let diff = *a.get_unchecked(i) - *b.get_unchecked(i);
        total += diff * diff;
    }

    total.sqrt()
}
363
/// NEON-optimized dot product for ARM64 (single accumulator, 4 floats per
/// iteration).
///
/// # Panics
/// Panics if `a.len() != b.len()`. This is a release-mode check: the vector
/// loads and `get_unchecked` reads below index `b` with `a`'s length, and the
/// safe public dispatchers do not validate lengths before calling this.
///
/// # Safety
/// Requires NEON. The function has no `#[target_feature]` attribute, so it
/// relies on NEON being enabled for the build target.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn dot_product_neon_impl(a: &[f32], b: &[f32]) -> f32 {
    // SECURITY: must hold in release builds too (see "Panics").
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let mut sum = vdupq_n_f32(0.0);

    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    let chunks = len / 4;
    let mut idx = 0usize;

    for _ in 0..chunks {
        // SAFETY: idx + 4 <= chunks * 4 <= len == b.len().
        let va = vld1q_f32(a_ptr.add(idx));
        let vb = vld1q_f32(b_ptr.add(idx));

        // Fused multiply-add: sum += a * b
        sum = vfmaq_f32(sum, va, vb);

        idx += 4;
    }

    let mut total = vaddvq_f32(sum);

    for i in (chunks * 4)..len {
        // SAFETY: i < len == b.len() (asserted above).
        total += *a.get_unchecked(i) * *b.get_unchecked(i);
    }

    total
}
401
/// NEON-optimized cosine similarity for ARM64: `dot(a, b) / (|a| * |b|)`,
/// 4 floats per iteration.
///
/// A zero-norm input divides by zero; no special case is made.
///
/// # Panics
/// Panics if `a.len() != b.len()`. This is a release-mode check: the vector
/// loads and `get_unchecked` reads below index `b` with `a`'s length, and the
/// safe public dispatchers do not validate lengths before calling this.
///
/// # Safety
/// Requires NEON. The function has no `#[target_feature]` attribute, so it
/// relies on NEON being enabled for the build target.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn cosine_similarity_neon_impl(a: &[f32], b: &[f32]) -> f32 {
    // SECURITY: must hold in release builds too (see "Panics").
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let mut dot = vdupq_n_f32(0.0);
    let mut norm_a = vdupq_n_f32(0.0);
    let mut norm_b = vdupq_n_f32(0.0);

    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    let chunks = len / 4;
    let mut idx = 0usize;

    for _ in 0..chunks {
        // SAFETY: idx + 4 <= chunks * 4 <= len == b.len().
        let va = vld1q_f32(a_ptr.add(idx));
        let vb = vld1q_f32(b_ptr.add(idx));

        // Dot product
        dot = vfmaq_f32(dot, va, vb);

        // Norms (squared)
        norm_a = vfmaq_f32(norm_a, va, va);
        norm_b = vfmaq_f32(norm_b, vb, vb);

        idx += 4;
    }

    let mut dot_sum = vaddvq_f32(dot);
    let mut norm_a_sum = vaddvq_f32(norm_a);
    let mut norm_b_sum = vaddvq_f32(norm_b);

    for i in (chunks * 4)..len {
        // SAFETY: i < len == b.len() (asserted above).
        let ai = *a.get_unchecked(i);
        let bi = *b.get_unchecked(i);
        dot_sum += ai * bi;
        norm_a_sum += ai * ai;
        norm_b_sum += bi * bi;
    }

    dot_sum / (norm_a_sum.sqrt() * norm_b_sum.sqrt())
}
451
/// NEON-optimized Manhattan (L1) distance for ARM64, 4 floats per iteration.
///
/// # Panics
/// Panics if `a.len() != b.len()`. This is a release-mode check: the vector
/// loads and `get_unchecked` reads below index `b` with `a`'s length, and the
/// safe public dispatchers do not validate lengths before calling this.
///
/// # Safety
/// Requires NEON. The function has no `#[target_feature]` attribute, so it
/// relies on NEON being enabled for the build target.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn manhattan_distance_neon_impl(a: &[f32], b: &[f32]) -> f32 {
    // SECURITY: must hold in release builds too (see "Panics").
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let mut sum = vdupq_n_f32(0.0);

    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    let chunks = len / 4;
    let mut idx = 0usize;

    for _ in 0..chunks {
        // SAFETY: idx + 4 <= chunks * 4 <= len == b.len().
        let va = vld1q_f32(a_ptr.add(idx));
        let vb = vld1q_f32(b_ptr.add(idx));

        // Absolute difference using vabdq_f32 (absolute difference in one instruction)
        let abs_diff = vabdq_f32(va, vb);
        sum = vaddq_f32(sum, abs_diff);

        idx += 4;
    }

    let mut total = vaddvq_f32(sum);

    for i in (chunks * 4)..len {
        // SAFETY: i < len == b.len() (asserted above).
        total += (*a.get_unchecked(i) - *b.get_unchecked(i)).abs();
    }

    total
}
490
/// NEON-optimized euclidean distance with 4x loop unrolling.
/// Intended for larger vectors (>= 64 elements).
///
/// Structure:
/// - 16 floats per main-loop iteration, spread over 4 independent
///   accumulators so consecutive FMAs do not depend on each other.
/// - The accumulators are merged pairwise.
/// - Leftover 4-float chunks, then 0-3 scalar elements, are added before the
///   square root.
///
/// # Panics
/// Panics if `a.len() != b.len()`. This is a release-mode check: the vector
/// loads and `get_unchecked` reads below index `b` with `a`'s length, and the
/// safe public dispatchers do not validate lengths before calling this.
///
/// # Safety
/// Requires NEON. The function has no `#[target_feature]` attribute, so it
/// relies on NEON being enabled for the build target.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn euclidean_distance_neon_unrolled_impl(a: &[f32], b: &[f32]) -> f32 {
    // SECURITY: must hold in release builds too (see "Panics").
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    // Use 4 accumulators for better instruction-level parallelism
    let mut sum0 = vdupq_n_f32(0.0);
    let mut sum1 = vdupq_n_f32(0.0);
    let mut sum2 = vdupq_n_f32(0.0);
    let mut sum3 = vdupq_n_f32(0.0);

    // Process 16 floats at a time (4 x 4 floats)
    let chunks = len / 16;
    let mut idx = 0usize;

    for _ in 0..chunks {
        // SAFETY: idx + 16 <= chunks * 16 <= len == b.len().
        let va0 = vld1q_f32(a_ptr.add(idx));
        let vb0 = vld1q_f32(b_ptr.add(idx));
        let diff0 = vsubq_f32(va0, vb0);
        sum0 = vfmaq_f32(sum0, diff0, diff0);

        let va1 = vld1q_f32(a_ptr.add(idx + 4));
        let vb1 = vld1q_f32(b_ptr.add(idx + 4));
        let diff1 = vsubq_f32(va1, vb1);
        sum1 = vfmaq_f32(sum1, diff1, diff1);

        let va2 = vld1q_f32(a_ptr.add(idx + 8));
        let vb2 = vld1q_f32(b_ptr.add(idx + 8));
        let diff2 = vsubq_f32(va2, vb2);
        sum2 = vfmaq_f32(sum2, diff2, diff2);

        let va3 = vld1q_f32(a_ptr.add(idx + 12));
        let vb3 = vld1q_f32(b_ptr.add(idx + 12));
        let diff3 = vsubq_f32(va3, vb3);
        sum3 = vfmaq_f32(sum3, diff3, diff3);

        idx += 16;
    }

    // Combine the 4 accumulators pairwise
    let sum01 = vaddq_f32(sum0, sum1);
    let sum23 = vaddq_f32(sum2, sum3);
    let sum = vaddq_f32(sum01, sum23);

    // Process remaining 4-float chunks
    let remaining_start = chunks * 16;
    let remaining_chunks = (len - remaining_start) / 4;
    let mut final_sum = sum;

    idx = remaining_start;
    for _ in 0..remaining_chunks {
        // SAFETY: idx + 4 <= remaining_start + remaining_chunks * 4 <= len.
        let va = vld1q_f32(a_ptr.add(idx));
        let vb = vld1q_f32(b_ptr.add(idx));
        let diff = vsubq_f32(va, vb);
        final_sum = vfmaq_f32(final_sum, diff, diff);
        idx += 4;
    }

    // Horizontal sum
    let mut total = vaddvq_f32(final_sum);

    // Handle remaining elements
    let scalar_start = remaining_start + remaining_chunks * 4;
    for i in scalar_start..len {
        // SAFETY: i < len == b.len() (asserted above).
        let diff = *a.get_unchecked(i) - *b.get_unchecked(i);
        total += diff * diff;
    }

    total.sqrt()
}
576
/// NEON-optimized dot product with 4x loop unrolling.
///
/// Uses 16 floats per main-loop iteration over 4 independent accumulators.
/// Leftover 4-float chunks are then folded into the merged accumulator, and
/// the last 0-3 elements are added in scalar code.
///
/// # Panics
/// Panics if `a.len() != b.len()`. This is a release-mode check: the vector
/// loads and `get_unchecked` reads below index `b` with `a`'s length, and the
/// safe public dispatchers do not validate lengths before calling this.
///
/// # Safety
/// Requires NEON. The function has no `#[target_feature]` attribute, so it
/// relies on NEON being enabled for the build target.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn dot_product_neon_unrolled_impl(a: &[f32], b: &[f32]) -> f32 {
    // SECURITY: must hold in release builds too (see "Panics").
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    let mut sum0 = vdupq_n_f32(0.0);
    let mut sum1 = vdupq_n_f32(0.0);
    let mut sum2 = vdupq_n_f32(0.0);
    let mut sum3 = vdupq_n_f32(0.0);

    let chunks = len / 16;
    let mut idx = 0usize;

    for _ in 0..chunks {
        // SAFETY: idx + 16 <= chunks * 16 <= len == b.len().
        let va0 = vld1q_f32(a_ptr.add(idx));
        let vb0 = vld1q_f32(b_ptr.add(idx));
        sum0 = vfmaq_f32(sum0, va0, vb0);

        let va1 = vld1q_f32(a_ptr.add(idx + 4));
        let vb1 = vld1q_f32(b_ptr.add(idx + 4));
        sum1 = vfmaq_f32(sum1, va1, vb1);

        let va2 = vld1q_f32(a_ptr.add(idx + 8));
        let vb2 = vld1q_f32(b_ptr.add(idx + 8));
        sum2 = vfmaq_f32(sum2, va2, vb2);

        let va3 = vld1q_f32(a_ptr.add(idx + 12));
        let vb3 = vld1q_f32(b_ptr.add(idx + 12));
        sum3 = vfmaq_f32(sum3, va3, vb3);

        idx += 16;
    }

    // Pairwise merge of the accumulators
    let sum01 = vaddq_f32(sum0, sum1);
    let sum23 = vaddq_f32(sum2, sum3);
    let sum = vaddq_f32(sum01, sum23);

    let remaining_start = chunks * 16;
    let remaining_chunks = (len - remaining_start) / 4;
    let mut final_sum = sum;

    idx = remaining_start;
    for _ in 0..remaining_chunks {
        // SAFETY: idx + 4 <= remaining_start + remaining_chunks * 4 <= len.
        let va = vld1q_f32(a_ptr.add(idx));
        let vb = vld1q_f32(b_ptr.add(idx));
        final_sum = vfmaq_f32(final_sum, va, vb);
        idx += 4;
    }

    let mut total = vaddvq_f32(final_sum);

    let scalar_start = remaining_start + remaining_chunks * 4;
    for i in scalar_start..len {
        // SAFETY: i < len == b.len() (asserted above).
        total += *a.get_unchecked(i) * *b.get_unchecked(i);
    }

    total
}
645
/// NEON-optimized cosine similarity with 2x loop unrolling.
///
/// Each main-loop iteration handles 8 floats using two sets of accumulators
/// (dot, |a|^2, |b|^2). The sets are merged, reduced with `vaddvq_f32`, and
/// the 0-7 leftover elements are added in scalar code. A zero-norm input
/// divides by zero; no special case is made.
///
/// # Panics
/// Panics if `a.len() != b.len()`. This is a release-mode check: the vector
/// loads and `get_unchecked` reads below index `b` with `a`'s length, and the
/// safe public dispatchers do not validate lengths before calling this.
///
/// # Safety
/// Requires NEON. The function has no `#[target_feature]` attribute, so it
/// relies on NEON being enabled for the build target.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn cosine_similarity_neon_unrolled_impl(a: &[f32], b: &[f32]) -> f32 {
    // SECURITY: must hold in release builds too (see "Panics").
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    let mut dot0 = vdupq_n_f32(0.0);
    let mut dot1 = vdupq_n_f32(0.0);
    let mut norm_a0 = vdupq_n_f32(0.0);
    let mut norm_a1 = vdupq_n_f32(0.0);
    let mut norm_b0 = vdupq_n_f32(0.0);
    let mut norm_b1 = vdupq_n_f32(0.0);

    let chunks = len / 8;
    let mut idx = 0usize;

    for _ in 0..chunks {
        // SAFETY: idx + 8 <= chunks * 8 <= len == b.len().
        let va0 = vld1q_f32(a_ptr.add(idx));
        let vb0 = vld1q_f32(b_ptr.add(idx));
        dot0 = vfmaq_f32(dot0, va0, vb0);
        norm_a0 = vfmaq_f32(norm_a0, va0, va0);
        norm_b0 = vfmaq_f32(norm_b0, vb0, vb0);

        let va1 = vld1q_f32(a_ptr.add(idx + 4));
        let vb1 = vld1q_f32(b_ptr.add(idx + 4));
        dot1 = vfmaq_f32(dot1, va1, vb1);
        norm_a1 = vfmaq_f32(norm_a1, va1, va1);
        norm_b1 = vfmaq_f32(norm_b1, vb1, vb1);

        idx += 8;
    }

    // Merge the two accumulator sets
    let dot = vaddq_f32(dot0, dot1);
    let norm_a = vaddq_f32(norm_a0, norm_a1);
    let norm_b = vaddq_f32(norm_b0, norm_b1);

    let mut dot_sum = vaddvq_f32(dot);
    let mut norm_a_sum = vaddvq_f32(norm_a);
    let mut norm_b_sum = vaddvq_f32(norm_b);

    for i in (chunks * 8)..len {
        // SAFETY: i < len == b.len() (asserted above).
        let ai = *a.get_unchecked(i);
        let bi = *b.get_unchecked(i);
        dot_sum += ai * bi;
        norm_a_sum += ai * ai;
        norm_b_sum += bi * bi;
    }

    dot_sum / (norm_a_sum.sqrt() * norm_b_sum.sqrt())
}
705
/// NEON-optimized Manhattan (L1) distance with 4x loop unrolling.
///
/// Uses 16 floats per main-loop iteration over 4 independent accumulators,
/// with `vabdq_f32` producing `|a - b|` in one instruction. Leftover 4-float
/// chunks are folded into the merged accumulator, and the last 0-3 elements
/// are added in scalar code.
///
/// # Panics
/// Panics if `a.len() != b.len()`. This is a release-mode check: the vector
/// loads and `get_unchecked` reads below index `b` with `a`'s length, and the
/// safe public dispatchers do not validate lengths before calling this.
///
/// # Safety
/// Requires NEON. The function has no `#[target_feature]` attribute, so it
/// relies on NEON being enabled for the build target.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn manhattan_distance_neon_unrolled_impl(a: &[f32], b: &[f32]) -> f32 {
    // SECURITY: must hold in release builds too (see "Panics").
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    let mut sum0 = vdupq_n_f32(0.0);
    let mut sum1 = vdupq_n_f32(0.0);
    let mut sum2 = vdupq_n_f32(0.0);
    let mut sum3 = vdupq_n_f32(0.0);

    let chunks = len / 16;
    let mut idx = 0usize;

    for _ in 0..chunks {
        // SAFETY: idx + 16 <= chunks * 16 <= len == b.len().
        let va0 = vld1q_f32(a_ptr.add(idx));
        let vb0 = vld1q_f32(b_ptr.add(idx));
        sum0 = vaddq_f32(sum0, vabdq_f32(va0, vb0));

        let va1 = vld1q_f32(a_ptr.add(idx + 4));
        let vb1 = vld1q_f32(b_ptr.add(idx + 4));
        sum1 = vaddq_f32(sum1, vabdq_f32(va1, vb1));

        let va2 = vld1q_f32(a_ptr.add(idx + 8));
        let vb2 = vld1q_f32(b_ptr.add(idx + 8));
        sum2 = vaddq_f32(sum2, vabdq_f32(va2, vb2));

        let va3 = vld1q_f32(a_ptr.add(idx + 12));
        let vb3 = vld1q_f32(b_ptr.add(idx + 12));
        sum3 = vaddq_f32(sum3, vabdq_f32(va3, vb3));

        idx += 16;
    }

    // Pairwise merge of the accumulators
    let sum01 = vaddq_f32(sum0, sum1);
    let sum23 = vaddq_f32(sum2, sum3);
    let sum = vaddq_f32(sum01, sum23);

    let remaining_start = chunks * 16;
    let remaining_chunks = (len - remaining_start) / 4;
    let mut final_sum = sum;

    idx = remaining_start;
    for _ in 0..remaining_chunks {
        // SAFETY: idx + 4 <= remaining_start + remaining_chunks * 4 <= len.
        let va = vld1q_f32(a_ptr.add(idx));
        let vb = vld1q_f32(b_ptr.add(idx));
        final_sum = vaddq_f32(final_sum, vabdq_f32(va, vb));
        idx += 4;
    }

    let mut total = vaddvq_f32(final_sum);

    let scalar_start = remaining_start + remaining_chunks * 4;
    for i in scalar_start..len {
        // SAFETY: i < len == b.len() (asserted above).
        total += (*a.get_unchecked(i) - *b.get_unchecked(i)).abs();
    }

    total
}
775
776// ============================================================================
777// Public API with architecture dispatch
778// ============================================================================
779
780/// SIMD-optimized dot product
781/// Uses AVX-512 > AVX2 on x86_64, NEON on ARM64/Apple Silicon
782#[inline(always)]
783pub fn dot_product_simd(a: &[f32], b: &[f32]) -> f32 {
784    #[cfg(target_arch = "x86_64")]
785    {
786        if is_x86_feature_detected!("avx512f") {
787            unsafe { dot_product_avx512_impl(a, b) }
788        } else if is_x86_feature_detected!("avx2") {
789            unsafe { dot_product_avx2_impl(a, b) }
790        } else {
791            dot_product_scalar(a, b)
792        }
793    }
794
795    #[cfg(target_arch = "aarch64")]
796    {
797        if a.len() >= 64 {
798            unsafe { dot_product_neon_unrolled_impl(a, b) }
799        } else {
800            unsafe { dot_product_neon_impl(a, b) }
801        }
802    }
803
804    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
805    {
806        dot_product_scalar(a, b)
807    }
808}
809
810/// Legacy alias for backward compatibility
811#[inline(always)]
812pub fn dot_product_avx2(a: &[f32], b: &[f32]) -> f32 {
813    dot_product_simd(a, b)
814}
815
/// Plain AVX2 dot product.
///
/// Multiplies and accumulates 8 lanes per iteration using separate `mul` and
/// `add` (no FMA), sums the lanes, then adds the scalar tail.
///
/// # Safety
/// The CPU must support AVX2.
///
/// # Panics
/// Panics if `a` and `b` differ in length.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn dot_product_avx2_impl(a: &[f32], b: &[f32]) -> f32 {
    // SECURITY: the raw pointer loads below read `b` using `a`'s length.
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let simd_len = a.len() / 8 * 8;
    let mut acc = _mm256_setzero_ps();
    let mut i = 0;
    while i < simd_len {
        let va = _mm256_loadu_ps(a.as_ptr().add(i));
        let vb = _mm256_loadu_ps(b.as_ptr().add(i));
        acc = _mm256_add_ps(acc, _mm256_mul_ps(va, vb));
        i += 8;
    }

    let lanes: [f32; 8] = std::mem::transmute(acc);
    let head: f32 = lanes.iter().sum();

    a[simd_len..]
        .iter()
        .zip(&b[simd_len..])
        .fold(head, |total, (&x, &y)| total + x * y)
}
843
844/// SIMD-optimized cosine similarity
845/// Uses AVX-512 > AVX2 on x86_64, NEON on ARM64/Apple Silicon
846#[inline(always)]
847pub fn cosine_similarity_simd(a: &[f32], b: &[f32]) -> f32 {
848    #[cfg(target_arch = "x86_64")]
849    {
850        if is_x86_feature_detected!("avx512f") {
851            unsafe { cosine_similarity_avx512_impl(a, b) }
852        } else if is_x86_feature_detected!("avx2") {
853            unsafe { cosine_similarity_avx2_impl(a, b) }
854        } else {
855            cosine_similarity_scalar(a, b)
856        }
857    }
858
859    #[cfg(target_arch = "aarch64")]
860    {
861        if a.len() >= 64 {
862            unsafe { cosine_similarity_neon_unrolled_impl(a, b) }
863        } else {
864            unsafe { cosine_similarity_neon_impl(a, b) }
865        }
866    }
867
868    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
869    {
870        cosine_similarity_scalar(a, b)
871    }
872}
873
874/// Legacy alias for backward compatibility
875#[inline(always)]
876pub fn cosine_similarity_avx2(a: &[f32], b: &[f32]) -> f32 {
877    cosine_similarity_simd(a, b)
878}
879
880/// SIMD-optimized Manhattan distance
881/// Uses AVX-512 > AVX2 on x86_64, NEON on ARM64/Apple Silicon, scalar on other platforms
882#[inline(always)]
883pub fn manhattan_distance_simd(a: &[f32], b: &[f32]) -> f32 {
884    #[cfg(target_arch = "x86_64")]
885    {
886        if is_x86_feature_detected!("avx512f") {
887            unsafe { manhattan_distance_avx512_impl(a, b) }
888        } else if is_x86_feature_detected!("avx2") {
889            unsafe { manhattan_distance_avx2_impl(a, b) }
890        } else {
891            manhattan_distance_scalar(a, b)
892        }
893    }
894
895    #[cfg(target_arch = "aarch64")]
896    {
897        if a.len() >= 64 {
898            unsafe { manhattan_distance_neon_unrolled_impl(a, b) }
899        } else {
900            unsafe { manhattan_distance_neon_impl(a, b) }
901        }
902    }
903
904    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
905    {
906        manhattan_distance_scalar(a, b)
907    }
908}
909
/// Plain AVX2 cosine similarity: `dot(a, b) / (|a| * |b|)`.
///
/// Accumulates the dot product and both squared norms 8 lanes at a time
/// using separate `mul` and `add` (no FMA), then handles the scalar tail.
/// A zero-norm input divides by zero; no special case is made.
///
/// # Safety
/// The CPU must support AVX2.
///
/// # Panics
/// Panics if `a` and `b` differ in length.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn cosine_similarity_avx2_impl(a: &[f32], b: &[f32]) -> f32 {
    // SECURITY: the raw pointer loads below read `b` using `a`'s length.
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let simd_len = a.len() / 8 * 8;

    let mut ab = _mm256_setzero_ps();
    let mut aa = _mm256_setzero_ps();
    let mut bb = _mm256_setzero_ps();

    for (xa, xb) in a[..simd_len].chunks_exact(8).zip(b[..simd_len].chunks_exact(8)) {
        let va = _mm256_loadu_ps(xa.as_ptr());
        let vb = _mm256_loadu_ps(xb.as_ptr());
        ab = _mm256_add_ps(ab, _mm256_mul_ps(va, vb));
        aa = _mm256_add_ps(aa, _mm256_mul_ps(va, va));
        bb = _mm256_add_ps(bb, _mm256_mul_ps(vb, vb));
    }

    let ab_lanes: [f32; 8] = std::mem::transmute(ab);
    let aa_lanes: [f32; 8] = std::mem::transmute(aa);
    let bb_lanes: [f32; 8] = std::mem::transmute(bb);

    let mut ab_sum: f32 = ab_lanes.iter().sum();
    let mut aa_sum: f32 = aa_lanes.iter().sum();
    let mut bb_sum: f32 = bb_lanes.iter().sum();

    for (&x, &y) in a[simd_len..].iter().zip(&b[simd_len..]) {
        ab_sum += x * y;
        aa_sum += x * x;
        bb_sum += y * y;
    }

    ab_sum / (aa_sum.sqrt() * bb_sum.sqrt())
}
951
952/// AVX2 Manhattan distance — processes 8 floats per iteration with absolute difference
953#[cfg(target_arch = "x86_64")]
954#[target_feature(enable = "avx2")]
955unsafe fn manhattan_distance_avx2_impl(a: &[f32], b: &[f32]) -> f32 {
956    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");
957
958    let len = a.len();
959    // Use sign-bit mask for absolute value: clear the sign bit
960    let sign_mask = _mm256_set1_ps(f32::from_bits(0x7FFF_FFFF));
961    let mut sum0 = _mm256_setzero_ps();
962    let mut sum1 = _mm256_setzero_ps();
963
964    // Process 16 floats at a time (2 x 8) for better ILP
965    let chunks = len / 16;
966    for i in 0..chunks {
967        let idx = i * 16;
968
969        let va0 = _mm256_loadu_ps(a.as_ptr().add(idx));
970        let vb0 = _mm256_loadu_ps(b.as_ptr().add(idx));
971        let diff0 = _mm256_sub_ps(va0, vb0);
972        let abs0 = _mm256_and_ps(diff0, sign_mask);
973        sum0 = _mm256_add_ps(sum0, abs0);
974
975        let va1 = _mm256_loadu_ps(a.as_ptr().add(idx + 8));
976        let vb1 = _mm256_loadu_ps(b.as_ptr().add(idx + 8));
977        let diff1 = _mm256_sub_ps(va1, vb1);
978        let abs1 = _mm256_and_ps(diff1, sign_mask);
979        sum1 = _mm256_add_ps(sum1, abs1);
980    }
981
982    let mut sum = _mm256_add_ps(sum0, sum1);
983
984    // Process remaining 8-float chunks
985    let remaining_start = chunks * 16;
986    let remaining_chunks = (len - remaining_start) / 8;
987    for i in 0..remaining_chunks {
988        let idx = remaining_start + i * 8;
989        let va = _mm256_loadu_ps(a.as_ptr().add(idx));
990        let vb = _mm256_loadu_ps(b.as_ptr().add(idx));
991        let diff = _mm256_sub_ps(va, vb);
992        let abs_diff = _mm256_and_ps(diff, sign_mask);
993        sum = _mm256_add_ps(sum, abs_diff);
994    }
995
996    // Horizontal sum
997    let sum_arr: [f32; 8] = std::mem::transmute(sum);
998    let mut total = sum_arr.iter().sum::<f32>();
999
1000    // Handle remaining elements
1001    let scalar_start = remaining_start + remaining_chunks * 8;
1002    for i in scalar_start..len {
1003        total += (a[i] - b[i]).abs();
1004    }
1005
1006    total
1007}
1008
1009// Scalar fallback implementations
1010// These are kept for architectures without SIMD support
1011
1012#[allow(dead_code)]
1013fn euclidean_distance_scalar(a: &[f32], b: &[f32]) -> f32 {
1014    a.iter()
1015        .zip(b.iter())
1016        .map(|(x, y)| {
1017            let diff = x - y;
1018            diff * diff
1019        })
1020        .sum::<f32>()
1021        .sqrt()
1022}
1023
1024#[allow(dead_code)]
1025fn dot_product_scalar(a: &[f32], b: &[f32]) -> f32 {
1026    a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
1027}
1028
1029#[allow(dead_code)]
1030fn cosine_similarity_scalar(a: &[f32], b: &[f32]) -> f32 {
1031    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
1032    let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
1033    let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
1034    dot / (norm_a * norm_b)
1035}
1036
1037#[allow(dead_code)]
1038fn manhattan_distance_scalar(a: &[f32], b: &[f32]) -> f32 {
1039    a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum()
1040}
1041
1042// ============================================================================
1043// INT8 Quantized Operations
1044// ============================================================================
1045
1046/// SIMD-accelerated dot product for INT8 quantized vectors
1047/// Uses NEON vdotq_s32 on ARM64, AVX2 _mm256_maddubs_epi16 on x86_64
1048#[inline(always)]
1049pub fn dot_product_i8(a: &[i8], b: &[i8]) -> i32 {
1050    #[cfg(target_arch = "x86_64")]
1051    {
1052        if is_x86_feature_detected!("avx2") {
1053            unsafe { dot_product_i8_avx2_impl(a, b) }
1054        } else {
1055            dot_product_i8_scalar(a, b)
1056        }
1057    }
1058
1059    #[cfg(target_arch = "aarch64")]
1060    {
1061        unsafe { dot_product_i8_neon_impl(a, b) }
1062    }
1063
1064    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
1065    {
1066        dot_product_i8_scalar(a, b)
1067    }
1068}
1069
1070/// SIMD-accelerated euclidean distance squared for INT8 quantized vectors
1071/// Returns squared distance (caller should sqrt if needed)
1072#[inline(always)]
1073pub fn euclidean_distance_squared_i8(a: &[i8], b: &[i8]) -> i32 {
1074    #[cfg(target_arch = "x86_64")]
1075    {
1076        if is_x86_feature_detected!("avx2") {
1077            unsafe { euclidean_distance_squared_i8_avx2_impl(a, b) }
1078        } else {
1079            euclidean_distance_squared_i8_scalar(a, b)
1080        }
1081    }
1082
1083    #[cfg(target_arch = "aarch64")]
1084    {
1085        unsafe { euclidean_distance_squared_i8_neon_impl(a, b) }
1086    }
1087
1088    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
1089    {
1090        euclidean_distance_squared_i8_scalar(a, b)
1091    }
1092}
1093
1094/// NEON INT8 dot product using stable intrinsics
1095/// Note: Uses sign extension and multiply-add instead of vdotq_s32 for stability
1096///
1097/// # Safety
1098/// Caller must ensure a.len() == b.len()
1099#[cfg(target_arch = "aarch64")]
1100#[inline(always)]
1101unsafe fn dot_product_i8_neon_impl(a: &[i8], b: &[i8]) -> i32 {
1102    debug_assert_eq!(a.len(), b.len(), "Input arrays must have the same length");
1103
1104    let len = a.len();
1105    let a_ptr = a.as_ptr();
1106    let b_ptr = b.as_ptr();
1107
1108    let mut sum = vdupq_n_s32(0);
1109
1110    // Process 8 int8s at a time (extend to i16, multiply, accumulate)
1111    let chunks = len / 8;
1112    let mut idx = 0usize;
1113
1114    for _ in 0..chunks {
1115        let va = vld1_s8(a_ptr.add(idx));
1116        let vb = vld1_s8(b_ptr.add(idx));
1117
1118        // Sign-extend to i16
1119        let va_i16 = vmovl_s8(va);
1120        let vb_i16 = vmovl_s8(vb);
1121
1122        // Multiply i16 * i16
1123        let prod_lo = vmull_s16(vget_low_s16(va_i16), vget_low_s16(vb_i16));
1124        let prod_hi = vmull_s16(vget_high_s16(va_i16), vget_high_s16(vb_i16));
1125
1126        // Accumulate
1127        sum = vaddq_s32(sum, prod_lo);
1128        sum = vaddq_s32(sum, prod_hi);
1129
1130        idx += 8;
1131    }
1132
1133    // Horizontal sum
1134    let mut total = vaddvq_s32(sum);
1135
1136    // Handle remaining elements with bounds-check elimination
1137    for i in (chunks * 8)..len {
1138        total += (*a.get_unchecked(i) as i32) * (*b.get_unchecked(i) as i32);
1139    }
1140
1141    total
1142}
1143
1144/// NEON INT8 euclidean distance squared using stable intrinsics
1145///
1146/// # Safety
1147/// Caller must ensure a.len() == b.len()
1148#[cfg(target_arch = "aarch64")]
1149#[inline(always)]
1150unsafe fn euclidean_distance_squared_i8_neon_impl(a: &[i8], b: &[i8]) -> i32 {
1151    debug_assert_eq!(a.len(), b.len(), "Input arrays must have the same length");
1152
1153    let len = a.len();
1154    let a_ptr = a.as_ptr();
1155    let b_ptr = b.as_ptr();
1156
1157    let mut sum = vdupq_n_s32(0);
1158
1159    // Process 8 int8s at a time
1160    let chunks = len / 8;
1161    let mut idx = 0usize;
1162
1163    for _ in 0..chunks {
1164        let va = vld1_s8(a_ptr.add(idx));
1165        let vb = vld1_s8(b_ptr.add(idx));
1166
1167        // Sign-extend to i16
1168        let va_i16 = vmovl_s8(va);
1169        let vb_i16 = vmovl_s8(vb);
1170
1171        // Compute difference in i16
1172        let diff = vsubq_s16(va_i16, vb_i16);
1173
1174        // Square and accumulate: diff^2
1175        let prod_lo = vmull_s16(vget_low_s16(diff), vget_low_s16(diff));
1176        let prod_hi = vmull_s16(vget_high_s16(diff), vget_high_s16(diff));
1177
1178        sum = vaddq_s32(sum, prod_lo);
1179        sum = vaddq_s32(sum, prod_hi);
1180
1181        idx += 8;
1182    }
1183
1184    let mut total = vaddvq_s32(sum);
1185
1186    // Handle remaining elements with bounds-check elimination
1187    for i in (chunks * 8)..len {
1188        let diff = (*a.get_unchecked(i) as i32) - (*b.get_unchecked(i) as i32);
1189        total += diff * diff;
1190    }
1191
1192    total
1193}
1194
1195/// AVX2 INT8 dot product
1196#[cfg(target_arch = "x86_64")]
1197#[target_feature(enable = "avx2")]
1198unsafe fn dot_product_i8_avx2_impl(a: &[i8], b: &[i8]) -> i32 {
1199    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");
1200
1201    let len = a.len();
1202    let mut sum = _mm256_setzero_si256();
1203
1204    // Process 32 int8s at a time
1205    let chunks = len / 32;
1206    for i in 0..chunks {
1207        let idx = i * 32;
1208        let va = _mm256_loadu_si256(a.as_ptr().add(idx) as *const __m256i);
1209        let vb = _mm256_loadu_si256(b.as_ptr().add(idx) as *const __m256i);
1210
1211        // For signed int8 multiply, we need to extend to i16 first
1212        let va_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(va));
1213        let vb_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(vb));
1214        let va_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(va, 1));
1215        let vb_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vb, 1));
1216
1217        let prod_lo = _mm256_madd_epi16(va_lo, vb_lo);
1218        let prod_hi = _mm256_madd_epi16(va_hi, vb_hi);
1219
1220        sum = _mm256_add_epi32(sum, prod_lo);
1221        sum = _mm256_add_epi32(sum, prod_hi);
1222    }
1223
1224    // Horizontal sum
1225    let sum_arr: [i32; 8] = std::mem::transmute(sum);
1226    let mut total: i32 = sum_arr.iter().sum();
1227
1228    // Handle remaining elements
1229    for i in (chunks * 32)..len {
1230        total += (a[i] as i32) * (b[i] as i32);
1231    }
1232
1233    total
1234}
1235
1236/// AVX2 INT8 euclidean distance squared
1237#[cfg(target_arch = "x86_64")]
1238#[target_feature(enable = "avx2")]
1239unsafe fn euclidean_distance_squared_i8_avx2_impl(a: &[i8], b: &[i8]) -> i32 {
1240    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");
1241
1242    let len = a.len();
1243    let mut sum = _mm256_setzero_si256();
1244
1245    let chunks = len / 32;
1246    for i in 0..chunks {
1247        let idx = i * 32;
1248        let va = _mm256_loadu_si256(a.as_ptr().add(idx) as *const __m256i);
1249        let vb = _mm256_loadu_si256(b.as_ptr().add(idx) as *const __m256i);
1250
1251        // Extend to i16, compute difference, then square
1252        let va_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(va));
1253        let vb_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(vb));
1254        let va_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(va, 1));
1255        let vb_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vb, 1));
1256
1257        let diff_lo = _mm256_sub_epi16(va_lo, vb_lo);
1258        let diff_hi = _mm256_sub_epi16(va_hi, vb_hi);
1259
1260        let sq_lo = _mm256_madd_epi16(diff_lo, diff_lo);
1261        let sq_hi = _mm256_madd_epi16(diff_hi, diff_hi);
1262
1263        sum = _mm256_add_epi32(sum, sq_lo);
1264        sum = _mm256_add_epi32(sum, sq_hi);
1265    }
1266
1267    let sum_arr: [i32; 8] = std::mem::transmute(sum);
1268    let mut total: i32 = sum_arr.iter().sum();
1269
1270    for i in (chunks * 32)..len {
1271        let diff = (a[i] as i32) - (b[i] as i32);
1272        total += diff * diff;
1273    }
1274
1275    total
1276}
1277
1278/// Scalar fallback for INT8 dot product
1279#[allow(dead_code)]
1280fn dot_product_i8_scalar(a: &[i8], b: &[i8]) -> i32 {
1281    a.iter()
1282        .zip(b.iter())
1283        .map(|(&x, &y)| (x as i32) * (y as i32))
1284        .sum()
1285}
1286
1287/// Scalar fallback for INT8 euclidean distance squared
1288#[allow(dead_code)]
1289fn euclidean_distance_squared_i8_scalar(a: &[i8], b: &[i8]) -> i32 {
1290    a.iter()
1291        .zip(b.iter())
1292        .map(|(&x, &y)| {
1293            let diff = (x as i32) - (y as i32);
1294            diff * diff
1295        })
1296        .sum()
1297}
1298
1299// ============================================================================
1300// Batch Operations (Cache-optimized)
1301// ============================================================================
1302
1303/// Batch dot product - compute dot products of one query vector against multiple vectors
1304/// Returns results in the provided output slice
1305/// Optimized for cache locality by processing vectors in tiles
1306#[inline]
1307pub fn batch_dot_product(query: &[f32], vectors: &[&[f32]], results: &mut [f32]) {
1308    assert_eq!(
1309        vectors.len(),
1310        results.len(),
1311        "Output size must match vector count"
1312    );
1313
1314    // Process in tiles for better cache utilization
1315    const TILE_SIZE: usize = 16;
1316
1317    for (chunk_idx, chunk) in vectors.chunks(TILE_SIZE).enumerate() {
1318        let base_idx = chunk_idx * TILE_SIZE;
1319        for (i, vec) in chunk.iter().enumerate() {
1320            results[base_idx + i] = dot_product_simd(query, vec);
1321        }
1322    }
1323}
1324
1325/// Batch euclidean distance - compute distances from one query to multiple vectors
1326/// Returns results in the provided output slice
1327/// Optimized for cache locality
1328#[inline]
1329pub fn batch_euclidean(query: &[f32], vectors: &[&[f32]], results: &mut [f32]) {
1330    assert_eq!(
1331        vectors.len(),
1332        results.len(),
1333        "Output size must match vector count"
1334    );
1335
1336    const TILE_SIZE: usize = 16;
1337
1338    for (chunk_idx, chunk) in vectors.chunks(TILE_SIZE).enumerate() {
1339        let base_idx = chunk_idx * TILE_SIZE;
1340        for (i, vec) in chunk.iter().enumerate() {
1341            results[base_idx + i] = euclidean_distance_simd(query, vec);
1342        }
1343    }
1344}
1345
1346/// Batch cosine similarity - compute similarities from one query to multiple vectors
1347#[inline]
1348pub fn batch_cosine_similarity(query: &[f32], vectors: &[&[f32]], results: &mut [f32]) {
1349    assert_eq!(
1350        vectors.len(),
1351        results.len(),
1352        "Output size must match vector count"
1353    );
1354
1355    const TILE_SIZE: usize = 16;
1356
1357    for (chunk_idx, chunk) in vectors.chunks(TILE_SIZE).enumerate() {
1358        let base_idx = chunk_idx * TILE_SIZE;
1359        for (i, vec) in chunk.iter().enumerate() {
1360            results[base_idx + i] = cosine_similarity_simd(query, vec);
1361        }
1362    }
1363}
1364
1365/// Batch dot product with owned vectors (for convenience)
1366#[inline]
1367pub fn batch_dot_product_owned(query: &[f32], vectors: &[Vec<f32>]) -> Vec<f32> {
1368    let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();
1369    let mut results = vec![0.0; vectors.len()];
1370    batch_dot_product(query, &refs, &mut results);
1371    results
1372}
1373
1374/// Batch euclidean distance with owned vectors (for convenience)
1375#[inline]
1376pub fn batch_euclidean_owned(query: &[f32], vectors: &[Vec<f32>]) -> Vec<f32> {
1377    let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();
1378    let mut results = vec![0.0; vectors.len()];
1379    batch_euclidean(query, &refs, &mut results);
1380    results
1381}
1382
1383#[cfg(test)]
1384mod tests {
1385    use super::*;
1386
1387    #[test]
1388    fn test_euclidean_distance_simd() {
1389        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
1390        let b = vec![2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
1391
1392        let result = euclidean_distance_simd(&a, &b);
1393        let expected = euclidean_distance_scalar(&a, &b);
1394
1395        assert!(
1396            (result - expected).abs() < 0.001,
1397            "SIMD result {} differs from scalar result {}",
1398            result,
1399            expected
1400        );
1401    }
1402
1403    #[test]
1404    fn test_euclidean_distance_large() {
1405        // Test with 128-dim vectors (common embedding size)
1406        let a: Vec<f32> = (0..128).map(|i| i as f32 * 0.1).collect();
1407        let b: Vec<f32> = (0..128).map(|i| (i as f32 * 0.1) + 0.5).collect();
1408
1409        let result = euclidean_distance_simd(&a, &b);
1410        let expected = euclidean_distance_scalar(&a, &b);
1411
1412        assert!(
1413            (result - expected).abs() < 0.01,
1414            "Large vector: SIMD {} vs scalar {}",
1415            result,
1416            expected
1417        );
1418    }
1419
1420    #[test]
1421    fn test_dot_product_simd() {
1422        let a = vec![1.0; 16];
1423        let b = vec![2.0; 16];
1424
1425        let result = dot_product_simd(&a, &b);
1426        assert!((result - 32.0).abs() < 0.001);
1427    }
1428
1429    #[test]
1430    fn test_dot_product_large() {
1431        let a: Vec<f32> = (0..256).map(|i| (i % 10) as f32).collect();
1432        let b: Vec<f32> = (0..256).map(|i| ((i + 5) % 10) as f32).collect();
1433
1434        let result = dot_product_simd(&a, &b);
1435        let expected = dot_product_scalar(&a, &b);
1436
1437        assert!(
1438            (result - expected).abs() < 0.1,
1439            "Large dot product: SIMD {} vs scalar {}",
1440            result,
1441            expected
1442        );
1443    }
1444
1445    #[test]
1446    fn test_cosine_similarity_simd() {
1447        let a = vec![1.0, 0.0, 0.0];
1448        let b = vec![1.0, 0.0, 0.0];
1449
1450        let result = cosine_similarity_simd(&a, &b);
1451        assert!((result - 1.0).abs() < 0.001);
1452    }
1453
1454    #[test]
1455    fn test_cosine_similarity_orthogonal() {
1456        let a = vec![1.0, 0.0, 0.0, 0.0];
1457        let b = vec![0.0, 1.0, 0.0, 0.0];
1458
1459        let result = cosine_similarity_simd(&a, &b);
1460        assert!(
1461            result.abs() < 0.001,
1462            "Orthogonal vectors should have ~0 similarity, got {}",
1463            result
1464        );
1465    }
1466
1467    #[test]
1468    fn test_manhattan_distance_simd() {
1469        let a = vec![1.0, 2.0, 3.0, 4.0];
1470        let b = vec![5.0, 6.0, 7.0, 8.0];
1471
1472        let result = manhattan_distance_simd(&a, &b);
1473        let expected = manhattan_distance_scalar(&a, &b);
1474
1475        assert!(
1476            (result - expected).abs() < 0.001,
1477            "Manhattan: SIMD {} vs scalar {}",
1478            result,
1479            expected
1480        );
1481        assert!((result - 16.0).abs() < 0.001); // |4| + |4| + |4| + |4| = 16
1482    }
1483
1484    #[test]
1485    fn test_non_aligned_lengths() {
1486        // Test vectors not aligned to SIMD width (4 for NEON, 8 for AVX2)
1487        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]; // 7 elements
1488        let b = vec![2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
1489
1490        let result = euclidean_distance_simd(&a, &b);
1491        let expected = euclidean_distance_scalar(&a, &b);
1492
1493        assert!(
1494            (result - expected).abs() < 0.001,
1495            "Non-aligned: SIMD {} vs scalar {}",
1496            result,
1497            expected
1498        );
1499    }
1500
1501    // Legacy function tests (ensure backward compatibility)
1502    #[test]
1503    fn test_legacy_avx2_aliases() {
1504        let a = vec![1.0, 2.0, 3.0, 4.0];
1505        let b = vec![5.0, 6.0, 7.0, 8.0];
1506
1507        // These should work identically to the _simd versions
1508        let _ = euclidean_distance_avx2(&a, &b);
1509        let _ = dot_product_avx2(&a, &b);
1510        let _ = cosine_similarity_avx2(&a, &b);
1511    }
1512
1513    // INT8 quantized operation tests
1514    #[test]
1515    fn test_dot_product_i8() {
1516        let a: Vec<i8> = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
1517        let b: Vec<i8> = vec![2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17];
1518
1519        let result = dot_product_i8(&a, &b);
1520        let expected = dot_product_i8_scalar(&a, &b);
1521
1522        assert_eq!(
1523            result, expected,
1524            "INT8 dot product: SIMD {} vs scalar {}",
1525            result, expected
1526        );
1527    }
1528
1529    #[test]
1530    fn test_dot_product_i8_large() {
1531        // Test with 128 elements (common for quantized embeddings)
1532        let a: Vec<i8> = (0..128)
1533            .map(|i| ((i % 256) as i8).wrapping_sub(64))
1534            .collect();
1535        let b: Vec<i8> = (0..128)
1536            .map(|i| (((i + 10) % 256) as i8).wrapping_sub(64))
1537            .collect();
1538
1539        let result = dot_product_i8(&a, &b);
1540        let expected = dot_product_i8_scalar(&a, &b);
1541
1542        assert_eq!(
1543            result, expected,
1544            "Large INT8 dot product: SIMD {} vs scalar {}",
1545            result, expected
1546        );
1547    }
1548
1549    #[test]
1550    fn test_euclidean_distance_squared_i8() {
1551        let a: Vec<i8> = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
1552        let b: Vec<i8> = vec![2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17];
1553
1554        let result = euclidean_distance_squared_i8(&a, &b);
1555        let expected = euclidean_distance_squared_i8_scalar(&a, &b);
1556
1557        assert_eq!(
1558            result, expected,
1559            "INT8 euclidean^2: SIMD {} vs scalar {}",
1560            result, expected
1561        );
1562        // Each diff is 1, so 16 diffs squared = 16
1563        assert_eq!(result, 16, "Expected 16, got {}", result);
1564    }
1565
1566    #[test]
1567    fn test_euclidean_distance_squared_i8_large() {
1568        let a: Vec<i8> = (0..128)
1569            .map(|i| ((i % 256) as i8).wrapping_sub(64))
1570            .collect();
1571        let b: Vec<i8> = (0..128)
1572            .map(|i| (((i + 5) % 256) as i8).wrapping_sub(64))
1573            .collect();
1574
1575        let result = euclidean_distance_squared_i8(&a, &b);
1576        let expected = euclidean_distance_squared_i8_scalar(&a, &b);
1577
1578        assert_eq!(
1579            result, expected,
1580            "Large INT8 euclidean^2: SIMD {} vs scalar {}",
1581            result, expected
1582        );
1583    }
1584
1585    // Batch operation tests
1586    #[test]
1587    fn test_batch_dot_product() {
1588        let query = vec![1.0, 2.0, 3.0, 4.0];
1589        let v1 = vec![1.0, 0.0, 0.0, 0.0];
1590        let v2 = vec![0.0, 1.0, 0.0, 0.0];
1591        let v3 = vec![0.0, 0.0, 1.0, 0.0];
1592        let vectors: Vec<&[f32]> = vec![&v1, &v2, &v3];
1593        let mut results = vec![0.0; 3];
1594
1595        batch_dot_product(&query, &vectors, &mut results);
1596
1597        assert!((results[0] - 1.0).abs() < 0.001);
1598        assert!((results[1] - 2.0).abs() < 0.001);
1599        assert!((results[2] - 3.0).abs() < 0.001);
1600    }
1601
1602    #[test]
1603    fn test_batch_euclidean() {
1604        let query = vec![0.0, 0.0, 0.0, 0.0];
1605        let v1 = vec![3.0, 4.0, 0.0, 0.0];
1606        let v2 = vec![0.0, 0.0, 5.0, 12.0];
1607        let vectors: Vec<&[f32]> = vec![&v1, &v2];
1608        let mut results = vec![0.0; 2];
1609
1610        batch_euclidean(&query, &vectors, &mut results);
1611
1612        assert!(
1613            (results[0] - 5.0).abs() < 0.001,
1614            "Expected 5.0, got {}",
1615            results[0]
1616        );
1617        assert!(
1618            (results[1] - 13.0).abs() < 0.001,
1619            "Expected 13.0, got {}",
1620            results[1]
1621        );
1622    }
1623
1624    #[test]
1625    fn test_batch_cosine_similarity() {
1626        let query = vec![1.0, 0.0, 0.0, 0.0];
1627        let v1 = vec![1.0, 0.0, 0.0, 0.0]; // Same direction
1628        let v2 = vec![0.0, 1.0, 0.0, 0.0]; // Orthogonal
1629        let v3 = vec![-1.0, 0.0, 0.0, 0.0]; // Opposite
1630        let vectors: Vec<&[f32]> = vec![&v1, &v2, &v3];
1631        let mut results = vec![0.0; 3];
1632
1633        batch_cosine_similarity(&query, &vectors, &mut results);
1634
1635        assert!(
1636            (results[0] - 1.0).abs() < 0.001,
1637            "Same direction should be 1.0"
1638        );
1639        assert!(results[1].abs() < 0.001, "Orthogonal should be 0.0");
1640        assert!((results[2] + 1.0).abs() < 0.001, "Opposite should be -1.0");
1641    }
1642
1643    #[test]
1644    fn test_batch_owned_convenience() {
1645        let query = vec![1.0, 2.0, 3.0, 4.0];
1646        let vectors = vec![vec![1.0, 0.0, 0.0, 0.0], vec![0.0, 1.0, 0.0, 0.0]];
1647
1648        let results = batch_dot_product_owned(&query, &vectors);
1649        assert_eq!(results.len(), 2);
1650        assert!((results[0] - 1.0).abs() < 0.001);
1651        assert!((results[1] - 2.0).abs() < 0.001);
1652    }
1653
1654    #[test]
1655    fn test_unrolled_vs_non_unrolled_consistency() {
1656        // Test that unrolled and non-unrolled implementations produce same results
1657        let a: Vec<f32> = (0..128).map(|i| i as f32 * 0.1).collect();
1658        let b: Vec<f32> = (0..128).map(|i| (i as f32 * 0.1) + 0.5).collect();
1659
1660        let result = euclidean_distance_simd(&a, &b);
1661        let expected = euclidean_distance_scalar(&a, &b);
1662
1663        assert!(
1664            (result - expected).abs() < 0.01,
1665            "Unrolled consistency: SIMD {} vs scalar {}",
1666            result,
1667            expected
1668        );
1669    }
1670}