// ruvector_core/simd_intrinsics.rs

//! Custom SIMD intrinsics for performance-critical operations
//!
//! This module provides hand-optimized SIMD implementations:
//! - AVX2/AVX-512 for x86_64 processors
//! - NEON for ARM64/Apple Silicon processors (M1/M2/M3/M4)
//!
//! Distance calculations and other vectorized operations are automatically
//! dispatched to the optimal implementation based on the target architecture.
//!
//! ## Features
//!
//! - **AVX-512 Support**: 512-bit operations processing 16 floats per iteration
//! - **INT8 Quantized Operations**: SIMD-accelerated quantized vector operations
//! - **Batch Operations**: Cache-optimized batch distance calculations
//! - **NEON Optimizations**: Prefetch hints and loop unrolling for ARM64
//!
//! ## Performance Optimizations (v2)
//!
//! - **Loop Unrolling**: 4x unrolled loops for better instruction-level parallelism
//! - **Prefetch Hints**: Software prefetching for large vectors (>256 elements)
//! - **FMA Instructions**: Fused multiply-add for improved throughput and accuracy
//! - **Efficient Horizontal Sum**: Optimized reduction operations
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;
29
/// Prefetch distance in cache lines (tuned for L1 cache, 64 bytes = 16 floats)
// NOTE(review): not referenced by any code visible in this module
// (`#[allow(dead_code)]`); kept as a tuning knob for prefetch-enabled paths.
#[allow(dead_code)]
const PREFETCH_DISTANCE: usize = 64;
33
34/// SIMD-optimized euclidean distance
35/// Uses AVX-512 > AVX2 on x86_64, NEON on ARM64/Apple Silicon, falls back to scalar otherwise
36///
37/// # Optimizations for M4 Pro (ARM64)
38/// - Uses 4x loop unrolling for vectors >= 64 elements
39/// - FMA instructions for improved throughput
40/// - Optimized horizontal reduction via `vaddvq_f32`
41#[inline(always)]
42pub fn euclidean_distance_simd(a: &[f32], b: &[f32]) -> f32 {
43    #[cfg(target_arch = "x86_64")]
44    {
45        if is_x86_feature_detected!("avx512f") {
46            unsafe { euclidean_distance_avx512_impl(a, b) }
47        } else if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
48            unsafe { euclidean_distance_avx2_fma_impl(a, b) }
49        } else if is_x86_feature_detected!("avx2") {
50            unsafe { euclidean_distance_avx2_impl(a, b) }
51        } else {
52            euclidean_distance_scalar(a, b)
53        }
54    }
55
56    #[cfg(target_arch = "aarch64")]
57    {
58        // Use unrolled version for vectors >= 64 elements for better ILP
59        if a.len() >= 64 {
60            unsafe { euclidean_distance_neon_unrolled_impl(a, b) }
61        } else {
62            unsafe { euclidean_distance_neon_impl(a, b) }
63        }
64    }
65
66    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
67    {
68        euclidean_distance_scalar(a, b)
69    }
70}
71
/// Legacy alias for backward compatibility.
///
/// Despite the `_avx2` suffix this simply forwards to
/// [`euclidean_distance_simd`], which selects the best available backend
/// (AVX-512, AVX2, NEON, or scalar) at runtime.
#[inline(always)]
pub fn euclidean_distance_avx2(a: &[f32], b: &[f32]) -> f32 {
    euclidean_distance_simd(a, b)
}
77
/// AVX2 euclidean distance: processes 8 f32 lanes per iteration.
///
/// # Safety
/// Caller must ensure the CPU supports AVX2 (the dispatcher checks
/// `is_x86_feature_detected!("avx2")`). Equal input lengths are enforced
/// by the `assert_eq!` below.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn euclidean_distance_avx2_impl(a: &[f32], b: &[f32]) -> f32 {
    // SECURITY: Ensure both arrays have the same length to prevent out-of-bounds access
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let mut sum = _mm256_setzero_ps();

    // Process 8 floats at a time
    let chunks = len / 8;
    for i in 0..chunks {
        let idx = i * 8;

        // Load 8 floats from each array (unaligned loads: slices carry no
        // 32-byte alignment guarantee)
        let va = _mm256_loadu_ps(a.as_ptr().add(idx));
        let vb = _mm256_loadu_ps(b.as_ptr().add(idx));

        // Compute difference: (a - b)
        let diff = _mm256_sub_ps(va, vb);

        // Square the difference: (a - b)^2
        let sq = _mm256_mul_ps(diff, diff);

        // Accumulate
        sum = _mm256_add_ps(sum, sq);
    }

    // Horizontal sum of the 8 floats in the AVX register
    let sum_arr: [f32; 8] = std::mem::transmute(sum);
    let mut total = sum_arr.iter().sum::<f32>();

    // Handle remaining elements (if len not divisible by 8)
    for i in (chunks * 8)..len {
        let diff = a[i] - b[i];
        total += diff * diff;
    }

    total.sqrt()
}
118
/// AVX2 with FMA - 4x loop unrolling for better instruction-level parallelism
/// (32 floats per main-loop iteration).
///
/// # Safety
/// Caller must ensure the CPU supports both AVX2 and FMA (the dispatcher
/// checks both features). Equal lengths are enforced by the `assert_eq!`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn euclidean_distance_avx2_fma_impl(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    // Use 4 accumulators for better ILP (instruction-level parallelism)
    let mut sum0 = _mm256_setzero_ps();
    let mut sum1 = _mm256_setzero_ps();
    let mut sum2 = _mm256_setzero_ps();
    let mut sum3 = _mm256_setzero_ps();

    // Process 32 floats at a time (4 x 8 floats)
    let chunks = len / 32;
    for i in 0..chunks {
        let idx = i * 32;

        // Load and process 4 vectors of 8 floats each; the four streams are
        // independent, letting the CPU overlap their latencies.
        let va0 = _mm256_loadu_ps(a.as_ptr().add(idx));
        let vb0 = _mm256_loadu_ps(b.as_ptr().add(idx));
        let diff0 = _mm256_sub_ps(va0, vb0);
        sum0 = _mm256_fmadd_ps(diff0, diff0, sum0);

        let va1 = _mm256_loadu_ps(a.as_ptr().add(idx + 8));
        let vb1 = _mm256_loadu_ps(b.as_ptr().add(idx + 8));
        let diff1 = _mm256_sub_ps(va1, vb1);
        sum1 = _mm256_fmadd_ps(diff1, diff1, sum1);

        let va2 = _mm256_loadu_ps(a.as_ptr().add(idx + 16));
        let vb2 = _mm256_loadu_ps(b.as_ptr().add(idx + 16));
        let diff2 = _mm256_sub_ps(va2, vb2);
        sum2 = _mm256_fmadd_ps(diff2, diff2, sum2);

        let va3 = _mm256_loadu_ps(a.as_ptr().add(idx + 24));
        let vb3 = _mm256_loadu_ps(b.as_ptr().add(idx + 24));
        let diff3 = _mm256_sub_ps(va3, vb3);
        sum3 = _mm256_fmadd_ps(diff3, diff3, sum3);
    }

    // Combine the 4 accumulators
    let sum01 = _mm256_add_ps(sum0, sum1);
    let sum23 = _mm256_add_ps(sum2, sum3);
    let sum = _mm256_add_ps(sum01, sum23);

    // Process remaining 8-float chunks
    let remaining_start = chunks * 32;
    let remaining_chunks = (len - remaining_start) / 8;
    let mut final_sum = sum;
    for i in 0..remaining_chunks {
        let idx = remaining_start + i * 8;
        let va = _mm256_loadu_ps(a.as_ptr().add(idx));
        let vb = _mm256_loadu_ps(b.as_ptr().add(idx));
        let diff = _mm256_sub_ps(va, vb);
        final_sum = _mm256_fmadd_ps(diff, diff, final_sum);
    }

    // Horizontal sum
    let sum_arr: [f32; 8] = std::mem::transmute(final_sum);
    let mut total = sum_arr.iter().sum::<f32>();

    // Handle remaining elements (last 0-7 floats)
    let scalar_start = remaining_start + remaining_chunks * 8;
    for i in scalar_start..len {
        let diff = a[i] - b[i];
        total += diff * diff;
    }

    total.sqrt()
}
189
// ============================================================================
// AVX-512 implementations for x86_64 (Intel Ice Lake, Sapphire Rapids, AMD Zen 4+)
// ============================================================================
193
/// AVX-512 euclidean distance - processes 16 floats per iteration.
///
/// # Safety
/// Caller must ensure the CPU supports AVX-512F (the dispatcher checks
/// `is_x86_feature_detected!("avx512f")`). Equal lengths are enforced by
/// the `assert_eq!` below.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn euclidean_distance_avx512_impl(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let mut sum = _mm512_setzero_ps();

    // Process 16 floats at a time with fused multiply-add.
    let chunks = len / 16;
    for i in 0..chunks {
        let idx = i * 16;
        let va = _mm512_loadu_ps(a.as_ptr().add(idx));
        let vb = _mm512_loadu_ps(b.as_ptr().add(idx));
        let diff = _mm512_sub_ps(va, vb);
        sum = _mm512_fmadd_ps(diff, diff, sum);
    }

    // Horizontal sum using AVX-512 reduction
    let mut total = _mm512_reduce_add_ps(sum);

    // Handle remaining elements (0-15 elements)
    for i in (chunks * 16)..len {
        let diff = a[i] - b[i];
        total += diff * diff;
    }

    total.sqrt()
}
224
/// AVX-512 dot product - processes 16 floats per iteration.
///
/// # Safety
/// Caller must ensure the CPU supports AVX-512F (the dispatcher checks
/// `is_x86_feature_detected!("avx512f")`). Equal lengths are enforced by
/// the `assert_eq!` below.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn dot_product_avx512_impl(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let mut sum = _mm512_setzero_ps();

    // Main loop: 16 lanes per iteration via fused multiply-add.
    let chunks = len / 16;
    for i in 0..chunks {
        let idx = i * 16;
        let va = _mm512_loadu_ps(a.as_ptr().add(idx));
        let vb = _mm512_loadu_ps(b.as_ptr().add(idx));
        sum = _mm512_fmadd_ps(va, vb, sum);
    }

    // Hardware-assisted horizontal reduction of the 16 lanes.
    let mut total = _mm512_reduce_add_ps(sum);

    // Scalar tail for the last (len % 16) elements.
    for i in (chunks * 16)..len {
        total += a[i] * b[i];
    }

    total
}
250
/// AVX-512 cosine similarity - processes 16 floats per iteration.
///
/// Returns NaN if either vector has zero norm (division by zero), matching
/// the other implementations in this module.
///
/// # Safety
/// Caller must ensure the CPU supports AVX-512F (the dispatcher checks
/// `is_x86_feature_detected!("avx512f")`). Equal lengths are enforced by
/// the `assert_eq!` below.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn cosine_similarity_avx512_impl(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    // Three independent accumulators: dot product plus both squared norms,
    // computed in a single pass over the data.
    let mut dot = _mm512_setzero_ps();
    let mut norm_a = _mm512_setzero_ps();
    let mut norm_b = _mm512_setzero_ps();

    let chunks = len / 16;
    for i in 0..chunks {
        let idx = i * 16;
        let va = _mm512_loadu_ps(a.as_ptr().add(idx));
        let vb = _mm512_loadu_ps(b.as_ptr().add(idx));

        dot = _mm512_fmadd_ps(va, vb, dot);
        norm_a = _mm512_fmadd_ps(va, va, norm_a);
        norm_b = _mm512_fmadd_ps(vb, vb, norm_b);
    }

    let mut dot_sum = _mm512_reduce_add_ps(dot);
    let mut norm_a_sum = _mm512_reduce_add_ps(norm_a);
    let mut norm_b_sum = _mm512_reduce_add_ps(norm_b);

    // Scalar tail for the last (len % 16) elements.
    for i in (chunks * 16)..len {
        dot_sum += a[i] * b[i];
        norm_a_sum += a[i] * a[i];
        norm_b_sum += b[i] * b[i];
    }

    dot_sum / (norm_a_sum.sqrt() * norm_b_sum.sqrt())
}
285
/// AVX-512 Manhattan distance - processes 16 floats per iteration.
///
/// # Safety
/// Caller must ensure the CPU supports AVX-512F (the dispatcher checks
/// `is_x86_feature_detected!("avx512f")`). Equal lengths are enforced by
/// the `assert_eq!` below.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn manhattan_distance_avx512_impl(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let mut sum = _mm512_setzero_ps();

    let chunks = len / 16;
    for i in 0..chunks {
        let idx = i * 16;
        let va = _mm512_loadu_ps(a.as_ptr().add(idx));
        let vb = _mm512_loadu_ps(b.as_ptr().add(idx));
        // |a - b| per lane, accumulated across the whole vector.
        let diff = _mm512_sub_ps(va, vb);
        let abs_diff = _mm512_abs_ps(diff);
        sum = _mm512_add_ps(sum, abs_diff);
    }

    let mut total = _mm512_reduce_add_ps(sum);

    // Scalar tail for the last (len % 16) elements.
    for i in (chunks * 16)..len {
        total += (a[i] - b[i]).abs();
    }

    total
}
313
// ============================================================================
// NEON implementations for ARM64/Apple Silicon (M1/M2/M3/M4)
// ============================================================================
317
/// NEON-optimized euclidean distance for ARM64 (original non-unrolled version)
/// Processes 4 floats at a time using 128-bit NEON registers
///
/// # Safety
/// Caller must ensure a.len() == b.len(). The check below is a
/// `debug_assert!` only — in release builds a length mismatch leads to
/// out-of-bounds reads via the raw pointers and `get_unchecked` calls.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
#[allow(dead_code)]
unsafe fn euclidean_distance_neon_impl(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    // Vector accumulator holding four partial sums, one per lane.
    let mut sum = vdupq_n_f32(0.0);

    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    // Process 4 floats at a time with NEON
    let chunks = len / 4;
    let mut idx = 0usize;

    for _ in 0..chunks {
        let va = vld1q_f32(a_ptr.add(idx));
        let vb = vld1q_f32(b_ptr.add(idx));

        // Compute difference: (a - b)
        let diff = vsubq_f32(va, vb);

        // Square and accumulate: sum += (a - b)^2
        sum = vfmaq_f32(sum, diff, diff);

        idx += 4;
    }

    // Horizontal sum of the 4 floats
    let mut total = vaddvq_f32(sum);

    // Handle remaining elements (use get_unchecked for bounds-check elimination)
    // SAFETY: i < len == a.len() == b.len() per the caller contract.
    for i in (chunks * 4)..len {
        let diff = *a.get_unchecked(i) - *b.get_unchecked(i);
        total += diff * diff;
    }

    total.sqrt()
}
363
/// NEON-optimized dot product for ARM64.
/// Processes 4 floats at a time using 128-bit NEON registers.
///
/// # Safety
/// Caller must ensure a.len() == b.len(). The check below is a
/// `debug_assert!` only — release builds read out of bounds if the
/// contract is violated.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn dot_product_neon_impl(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    // Vector accumulator holding four partial sums, one per lane.
    let mut sum = vdupq_n_f32(0.0);

    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    let chunks = len / 4;
    let mut idx = 0usize;

    for _ in 0..chunks {
        let va = vld1q_f32(a_ptr.add(idx));
        let vb = vld1q_f32(b_ptr.add(idx));

        // Fused multiply-add: sum += a * b
        sum = vfmaq_f32(sum, va, vb);

        idx += 4;
    }

    let mut total = vaddvq_f32(sum);

    // Handle remaining elements with bounds-check elimination
    // SAFETY: i < len == a.len() == b.len() per the caller contract.
    for i in (chunks * 4)..len {
        total += *a.get_unchecked(i) * *b.get_unchecked(i);
    }

    total
}
401
/// NEON-optimized cosine similarity for ARM64.
///
/// Computes dot product and both squared norms in a single pass; returns
/// NaN if either vector has zero norm, matching the other implementations.
///
/// # Safety
/// Caller must ensure a.len() == b.len(). The check below is a
/// `debug_assert!` only — release builds read out of bounds if the
/// contract is violated.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn cosine_similarity_neon_impl(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let mut dot = vdupq_n_f32(0.0);
    let mut norm_a = vdupq_n_f32(0.0);
    let mut norm_b = vdupq_n_f32(0.0);

    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    let chunks = len / 4;
    let mut idx = 0usize;

    for _ in 0..chunks {
        let va = vld1q_f32(a_ptr.add(idx));
        let vb = vld1q_f32(b_ptr.add(idx));

        // Dot product
        dot = vfmaq_f32(dot, va, vb);

        // Norms (squared)
        norm_a = vfmaq_f32(norm_a, va, va);
        norm_b = vfmaq_f32(norm_b, vb, vb);

        idx += 4;
    }

    let mut dot_sum = vaddvq_f32(dot);
    let mut norm_a_sum = vaddvq_f32(norm_a);
    let mut norm_b_sum = vaddvq_f32(norm_b);

    // Handle remaining elements with bounds-check elimination
    // SAFETY: i < len == a.len() == b.len() per the caller contract.
    for i in (chunks * 4)..len {
        let ai = *a.get_unchecked(i);
        let bi = *b.get_unchecked(i);
        dot_sum += ai * bi;
        norm_a_sum += ai * ai;
        norm_b_sum += bi * bi;
    }

    dot_sum / (norm_a_sum.sqrt() * norm_b_sum.sqrt())
}
451
/// NEON-optimized Manhattan distance for ARM64.
///
/// # Safety
/// Caller must ensure a.len() == b.len(). The check below is a
/// `debug_assert!` only — release builds read out of bounds if the
/// contract is violated.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn manhattan_distance_neon_impl(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let mut sum = vdupq_n_f32(0.0);

    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    let chunks = len / 4;
    let mut idx = 0usize;

    for _ in 0..chunks {
        let va = vld1q_f32(a_ptr.add(idx));
        let vb = vld1q_f32(b_ptr.add(idx));

        // Absolute difference using vabdq_f32 (absolute difference in one instruction)
        let abs_diff = vabdq_f32(va, vb);
        sum = vaddq_f32(sum, abs_diff);

        idx += 4;
    }

    let mut total = vaddvq_f32(sum);

    // Handle remaining elements with bounds-check elimination
    // SAFETY: i < len == a.len() == b.len() per the caller contract.
    for i in (chunks * 4)..len {
        total += (*a.get_unchecked(i) - *b.get_unchecked(i)).abs();
    }

    total
}
490
/// NEON-optimized euclidean distance with 4x loop unrolling.
/// Optimized for larger vectors (>= 64 elements) common in ML embeddings.
///
/// Uses 4 independent accumulators for instruction-level parallelism and
/// bounds-check elimination in the remainder loops.
// NOTE(review): the original doc claimed software prefetching for vectors
// > 256 elements, but no prefetch instructions are present in this body —
// confirm whether that lives elsewhere or was removed.
///
/// # Safety
/// Caller must ensure a.len() == b.len(). The check below is a
/// `debug_assert!` only — release builds read out of bounds if the
/// contract is violated.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn euclidean_distance_neon_unrolled_impl(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    // Use 4 accumulators for better instruction-level parallelism
    let mut sum0 = vdupq_n_f32(0.0);
    let mut sum1 = vdupq_n_f32(0.0);
    let mut sum2 = vdupq_n_f32(0.0);
    let mut sum3 = vdupq_n_f32(0.0);

    // Process 16 floats at a time (4 x 4 floats)
    let chunks = len / 16;
    let mut idx = 0usize;

    for _ in 0..chunks {
        // Unroll 4x for better ILP - all loads and operations are independent
        let va0 = vld1q_f32(a_ptr.add(idx));
        let vb0 = vld1q_f32(b_ptr.add(idx));
        let diff0 = vsubq_f32(va0, vb0);
        sum0 = vfmaq_f32(sum0, diff0, diff0);

        let va1 = vld1q_f32(a_ptr.add(idx + 4));
        let vb1 = vld1q_f32(b_ptr.add(idx + 4));
        let diff1 = vsubq_f32(va1, vb1);
        sum1 = vfmaq_f32(sum1, diff1, diff1);

        let va2 = vld1q_f32(a_ptr.add(idx + 8));
        let vb2 = vld1q_f32(b_ptr.add(idx + 8));
        let diff2 = vsubq_f32(va2, vb2);
        sum2 = vfmaq_f32(sum2, diff2, diff2);

        let va3 = vld1q_f32(a_ptr.add(idx + 12));
        let vb3 = vld1q_f32(b_ptr.add(idx + 12));
        let diff3 = vsubq_f32(va3, vb3);
        sum3 = vfmaq_f32(sum3, diff3, diff3);

        idx += 16;
    }

    // Combine the 4 accumulators (tree reduction for latency hiding)
    let sum01 = vaddq_f32(sum0, sum1);
    let sum23 = vaddq_f32(sum2, sum3);
    let sum = vaddq_f32(sum01, sum23);

    // Process remaining 4-float chunks
    let remaining_start = chunks * 16;
    let remaining_chunks = (len - remaining_start) / 4;
    let mut final_sum = sum;

    idx = remaining_start;
    for _ in 0..remaining_chunks {
        let va = vld1q_f32(a_ptr.add(idx));
        let vb = vld1q_f32(b_ptr.add(idx));
        let diff = vsubq_f32(va, vb);
        final_sum = vfmaq_f32(final_sum, diff, diff);
        idx += 4;
    }

    // Horizontal sum
    let mut total = vaddvq_f32(final_sum);

    // Handle remaining elements with bounds-check elimination
    // SAFETY: i < len == a.len() == b.len() per the caller contract.
    let scalar_start = remaining_start + remaining_chunks * 4;
    for i in scalar_start..len {
        let diff = *a.get_unchecked(i) - *b.get_unchecked(i);
        total += diff * diff;
    }

    total.sqrt()
}
576
/// NEON-optimized dot product with 4x loop unrolling
/// (16 floats per main-loop iteration, 4 independent accumulators).
///
/// # Safety
/// Caller must ensure a.len() == b.len(). The check below is a
/// `debug_assert!` only — release builds read out of bounds if the
/// contract is violated.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn dot_product_neon_unrolled_impl(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    let mut sum0 = vdupq_n_f32(0.0);
    let mut sum1 = vdupq_n_f32(0.0);
    let mut sum2 = vdupq_n_f32(0.0);
    let mut sum3 = vdupq_n_f32(0.0);

    let chunks = len / 16;
    let mut idx = 0usize;

    for _ in 0..chunks {
        let va0 = vld1q_f32(a_ptr.add(idx));
        let vb0 = vld1q_f32(b_ptr.add(idx));
        sum0 = vfmaq_f32(sum0, va0, vb0);

        let va1 = vld1q_f32(a_ptr.add(idx + 4));
        let vb1 = vld1q_f32(b_ptr.add(idx + 4));
        sum1 = vfmaq_f32(sum1, va1, vb1);

        let va2 = vld1q_f32(a_ptr.add(idx + 8));
        let vb2 = vld1q_f32(b_ptr.add(idx + 8));
        sum2 = vfmaq_f32(sum2, va2, vb2);

        let va3 = vld1q_f32(a_ptr.add(idx + 12));
        let vb3 = vld1q_f32(b_ptr.add(idx + 12));
        sum3 = vfmaq_f32(sum3, va3, vb3);

        idx += 16;
    }

    // Tree reduction for latency hiding
    let sum01 = vaddq_f32(sum0, sum1);
    let sum23 = vaddq_f32(sum2, sum3);
    let sum = vaddq_f32(sum01, sum23);

    // Process remaining 4-float chunks
    let remaining_start = chunks * 16;
    let remaining_chunks = (len - remaining_start) / 4;
    let mut final_sum = sum;

    idx = remaining_start;
    for _ in 0..remaining_chunks {
        let va = vld1q_f32(a_ptr.add(idx));
        let vb = vld1q_f32(b_ptr.add(idx));
        final_sum = vfmaq_f32(final_sum, va, vb);
        idx += 4;
    }

    let mut total = vaddvq_f32(final_sum);

    // Bounds-check elimination in remainder
    // SAFETY: i < len == a.len() == b.len() per the caller contract.
    let scalar_start = remaining_start + remaining_chunks * 4;
    for i in scalar_start..len {
        total += *a.get_unchecked(i) * *b.get_unchecked(i);
    }

    total
}
645
/// NEON-optimized cosine similarity with 2x loop unrolling
/// (8 floats per main-loop iteration; dot product and both squared norms
/// computed in one pass). Returns NaN if either vector has zero norm.
///
/// # Safety
/// Caller must ensure a.len() == b.len(). The check below is a
/// `debug_assert!` only — release builds read out of bounds if the
/// contract is violated.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn cosine_similarity_neon_unrolled_impl(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    // Two accumulators per quantity (2x unroll); six registers total.
    let mut dot0 = vdupq_n_f32(0.0);
    let mut dot1 = vdupq_n_f32(0.0);
    let mut norm_a0 = vdupq_n_f32(0.0);
    let mut norm_a1 = vdupq_n_f32(0.0);
    let mut norm_b0 = vdupq_n_f32(0.0);
    let mut norm_b1 = vdupq_n_f32(0.0);

    let chunks = len / 8;
    let mut idx = 0usize;

    for _ in 0..chunks {
        let va0 = vld1q_f32(a_ptr.add(idx));
        let vb0 = vld1q_f32(b_ptr.add(idx));
        dot0 = vfmaq_f32(dot0, va0, vb0);
        norm_a0 = vfmaq_f32(norm_a0, va0, va0);
        norm_b0 = vfmaq_f32(norm_b0, vb0, vb0);

        let va1 = vld1q_f32(a_ptr.add(idx + 4));
        let vb1 = vld1q_f32(b_ptr.add(idx + 4));
        dot1 = vfmaq_f32(dot1, va1, vb1);
        norm_a1 = vfmaq_f32(norm_a1, va1, va1);
        norm_b1 = vfmaq_f32(norm_b1, vb1, vb1);

        idx += 8;
    }

    // Tree reduction
    let dot = vaddq_f32(dot0, dot1);
    let norm_a = vaddq_f32(norm_a0, norm_a1);
    let norm_b = vaddq_f32(norm_b0, norm_b1);

    let mut dot_sum = vaddvq_f32(dot);
    let mut norm_a_sum = vaddvq_f32(norm_a);
    let mut norm_b_sum = vaddvq_f32(norm_b);

    // Bounds-check elimination in remainder
    // SAFETY: i < len == a.len() == b.len() per the caller contract.
    for i in (chunks * 8)..len {
        let ai = *a.get_unchecked(i);
        let bi = *b.get_unchecked(i);
        dot_sum += ai * bi;
        norm_a_sum += ai * ai;
        norm_b_sum += bi * bi;
    }

    dot_sum / (norm_a_sum.sqrt() * norm_b_sum.sqrt())
}
705
/// NEON-optimized Manhattan distance with 4x loop unrolling
/// (16 floats per main-loop iteration, 4 independent accumulators).
///
/// # Safety
/// Caller must ensure a.len() == b.len(). The check below is a
/// `debug_assert!` only — release builds read out of bounds if the
/// contract is violated.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn manhattan_distance_neon_unrolled_impl(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    let mut sum0 = vdupq_n_f32(0.0);
    let mut sum1 = vdupq_n_f32(0.0);
    let mut sum2 = vdupq_n_f32(0.0);
    let mut sum3 = vdupq_n_f32(0.0);

    let chunks = len / 16;
    let mut idx = 0usize;

    for _ in 0..chunks {
        // Use vabdq_f32 for absolute difference in one instruction
        let va0 = vld1q_f32(a_ptr.add(idx));
        let vb0 = vld1q_f32(b_ptr.add(idx));
        sum0 = vaddq_f32(sum0, vabdq_f32(va0, vb0));

        let va1 = vld1q_f32(a_ptr.add(idx + 4));
        let vb1 = vld1q_f32(b_ptr.add(idx + 4));
        sum1 = vaddq_f32(sum1, vabdq_f32(va1, vb1));

        let va2 = vld1q_f32(a_ptr.add(idx + 8));
        let vb2 = vld1q_f32(b_ptr.add(idx + 8));
        sum2 = vaddq_f32(sum2, vabdq_f32(va2, vb2));

        let va3 = vld1q_f32(a_ptr.add(idx + 12));
        let vb3 = vld1q_f32(b_ptr.add(idx + 12));
        sum3 = vaddq_f32(sum3, vabdq_f32(va3, vb3));

        idx += 16;
    }

    // Tree reduction
    let sum01 = vaddq_f32(sum0, sum1);
    let sum23 = vaddq_f32(sum2, sum3);
    let sum = vaddq_f32(sum01, sum23);

    // Process remaining 4-float chunks
    let remaining_start = chunks * 16;
    let remaining_chunks = (len - remaining_start) / 4;
    let mut final_sum = sum;

    idx = remaining_start;
    for _ in 0..remaining_chunks {
        let va = vld1q_f32(a_ptr.add(idx));
        let vb = vld1q_f32(b_ptr.add(idx));
        final_sum = vaddq_f32(final_sum, vabdq_f32(va, vb));
        idx += 4;
    }

    let mut total = vaddvq_f32(final_sum);

    // Bounds-check elimination in remainder
    // SAFETY: i < len == a.len() == b.len() per the caller contract.
    let scalar_start = remaining_start + remaining_chunks * 4;
    for i in scalar_start..len {
        total += (*a.get_unchecked(i) - *b.get_unchecked(i)).abs();
    }

    total
}
775
// ============================================================================
// Public API with architecture dispatch
// ============================================================================
779
780/// SIMD-optimized dot product
781/// Uses AVX-512 > AVX2 on x86_64, NEON on ARM64/Apple Silicon
782#[inline(always)]
783pub fn dot_product_simd(a: &[f32], b: &[f32]) -> f32 {
784    #[cfg(target_arch = "x86_64")]
785    {
786        if is_x86_feature_detected!("avx512f") {
787            unsafe { dot_product_avx512_impl(a, b) }
788        } else if is_x86_feature_detected!("avx2") {
789            unsafe { dot_product_avx2_impl(a, b) }
790        } else {
791            dot_product_scalar(a, b)
792        }
793    }
794
795    #[cfg(target_arch = "aarch64")]
796    {
797        if a.len() >= 64 {
798            unsafe { dot_product_neon_unrolled_impl(a, b) }
799        } else {
800            unsafe { dot_product_neon_impl(a, b) }
801        }
802    }
803
804    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
805    {
806        dot_product_scalar(a, b)
807    }
808}
809
/// Legacy alias for backward compatibility.
///
/// Despite the `_avx2` suffix this simply forwards to [`dot_product_simd`],
/// which selects the best available backend at runtime.
#[inline(always)]
pub fn dot_product_avx2(a: &[f32], b: &[f32]) -> f32 {
    dot_product_simd(a, b)
}
815
/// AVX2 dot product: processes 8 f32 lanes per iteration.
///
/// # Safety
/// Caller must ensure the CPU supports AVX2 (the dispatcher checks
/// `is_x86_feature_detected!("avx2")`). Equal lengths are enforced by
/// the `assert_eq!` below.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn dot_product_avx2_impl(a: &[f32], b: &[f32]) -> f32 {
    // SECURITY: Ensure both arrays have the same length to prevent out-of-bounds access
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let mut sum = _mm256_setzero_ps();

    let chunks = len / 8;
    for i in 0..chunks {
        let idx = i * 8;
        let va = _mm256_loadu_ps(a.as_ptr().add(idx));
        let vb = _mm256_loadu_ps(b.as_ptr().add(idx));
        let prod = _mm256_mul_ps(va, vb);
        sum = _mm256_add_ps(sum, prod);
    }

    // Horizontal sum of the 8 lanes via a plain array view.
    let sum_arr: [f32; 8] = std::mem::transmute(sum);
    let mut total = sum_arr.iter().sum::<f32>();

    // Scalar tail for the last (len % 8) elements.
    for i in (chunks * 8)..len {
        total += a[i] * b[i];
    }

    total
}
843
844/// SIMD-optimized cosine similarity
845/// Uses AVX-512 > AVX2 on x86_64, NEON on ARM64/Apple Silicon
846#[inline(always)]
847pub fn cosine_similarity_simd(a: &[f32], b: &[f32]) -> f32 {
848    #[cfg(target_arch = "x86_64")]
849    {
850        if is_x86_feature_detected!("avx512f") {
851            unsafe { cosine_similarity_avx512_impl(a, b) }
852        } else if is_x86_feature_detected!("avx2") {
853            unsafe { cosine_similarity_avx2_impl(a, b) }
854        } else {
855            cosine_similarity_scalar(a, b)
856        }
857    }
858
859    #[cfg(target_arch = "aarch64")]
860    {
861        if a.len() >= 64 {
862            unsafe { cosine_similarity_neon_unrolled_impl(a, b) }
863        } else {
864            unsafe { cosine_similarity_neon_impl(a, b) }
865        }
866    }
867
868    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
869    {
870        cosine_similarity_scalar(a, b)
871    }
872}
873
/// Legacy alias for backward compatibility.
///
/// Despite the `_avx2` suffix this simply forwards to
/// [`cosine_similarity_simd`], which selects the best available backend
/// at runtime.
#[inline(always)]
pub fn cosine_similarity_avx2(a: &[f32], b: &[f32]) -> f32 {
    cosine_similarity_simd(a, b)
}
879
880/// SIMD-optimized Manhattan distance
881/// Uses AVX-512 on x86_64, NEON on ARM64/Apple Silicon, scalar on other platforms
882#[inline(always)]
883pub fn manhattan_distance_simd(a: &[f32], b: &[f32]) -> f32 {
884    #[cfg(target_arch = "x86_64")]
885    {
886        if is_x86_feature_detected!("avx512f") {
887            unsafe { manhattan_distance_avx512_impl(a, b) }
888        } else {
889            manhattan_distance_scalar(a, b)
890        }
891    }
892
893    #[cfg(target_arch = "aarch64")]
894    {
895        if a.len() >= 64 {
896            unsafe { manhattan_distance_neon_unrolled_impl(a, b) }
897        } else {
898            unsafe { manhattan_distance_neon_impl(a, b) }
899        }
900    }
901
902    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
903    {
904        manhattan_distance_scalar(a, b)
905    }
906}
907
/// AVX2 cosine similarity: processes 8 f32 lanes per iteration, computing
/// the dot product and both squared norms in a single pass. Returns NaN if
/// either vector has zero norm.
///
/// # Safety
/// Caller must ensure the CPU supports AVX2 (the dispatcher checks
/// `is_x86_feature_detected!("avx2")`). Equal lengths are enforced by
/// the `assert_eq!` below.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn cosine_similarity_avx2_impl(a: &[f32], b: &[f32]) -> f32 {
    // SECURITY: Ensure both arrays have the same length to prevent out-of-bounds access
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let mut dot = _mm256_setzero_ps();
    let mut norm_a = _mm256_setzero_ps();
    let mut norm_b = _mm256_setzero_ps();

    let chunks = len / 8;
    for i in 0..chunks {
        let idx = i * 8;
        let va = _mm256_loadu_ps(a.as_ptr().add(idx));
        let vb = _mm256_loadu_ps(b.as_ptr().add(idx));

        // Dot product
        dot = _mm256_add_ps(dot, _mm256_mul_ps(va, vb));

        // Norms
        norm_a = _mm256_add_ps(norm_a, _mm256_mul_ps(va, va));
        norm_b = _mm256_add_ps(norm_b, _mm256_mul_ps(vb, vb));
    }

    // Horizontal sums via plain array views of the registers.
    let dot_arr: [f32; 8] = std::mem::transmute(dot);
    let norm_a_arr: [f32; 8] = std::mem::transmute(norm_a);
    let norm_b_arr: [f32; 8] = std::mem::transmute(norm_b);

    let mut dot_sum = dot_arr.iter().sum::<f32>();
    let mut norm_a_sum = norm_a_arr.iter().sum::<f32>();
    let mut norm_b_sum = norm_b_arr.iter().sum::<f32>();

    // Scalar tail for the last (len % 8) elements.
    for i in (chunks * 8)..len {
        dot_sum += a[i] * b[i];
        norm_a_sum += a[i] * a[i];
        norm_b_sum += b[i] * b[i];
    }

    dot_sum / (norm_a_sum.sqrt() * norm_b_sum.sqrt())
}
949
// Scalar fallback implementations

/// Portable scalar fallback for euclidean (L2) distance.
///
/// Walks both slices in lock-step; any extra elements of the longer slice
/// are ignored (standard `zip` semantics).
fn euclidean_distance_scalar(a: &[f32], b: &[f32]) -> f32 {
    let mut acc = 0.0f32;
    for (x, y) in a.iter().zip(b.iter()) {
        let d = x - y;
        acc += d * d;
    }
    acc.sqrt()
}
962
/// Portable dot product fallback: left-to-right sum of pairwise products.
fn dot_product_scalar(a: &[f32], b: &[f32]) -> f32 {
    let mut acc = 0.0f32;
    for (x, y) in a.iter().zip(b) {
        acc += x * y;
    }
    acc
}
966
/// Portable cosine similarity fallback: `dot(a, b) / (||a|| * ||b||)`.
/// The dot product pairs elements with `zip`, while each norm is taken over
/// its full slice (mirroring the original behavior for unequal lengths).
/// Returns NaN when either norm is zero.
fn cosine_similarity_scalar(a: &[f32], b: &[f32]) -> f32 {
    let mut dot = 0.0f32;
    for (x, y) in a.iter().zip(b) {
        dot += x * y;
    }
    let magnitude = |v: &[f32]| v.iter().fold(0.0f32, |s, x| s + x * x).sqrt();
    dot / (magnitude(a) * magnitude(b))
}
973
/// Portable Manhattan (L1) distance fallback: sum of absolute component
/// differences over paired elements.
fn manhattan_distance_scalar(a: &[f32], b: &[f32]) -> f32 {
    let mut total = 0.0f32;
    for (x, y) in a.iter().zip(b) {
        total += (x - y).abs();
    }
    total
}
977
978// ============================================================================
979// INT8 Quantized Operations
980// ============================================================================
981
982/// SIMD-accelerated dot product for INT8 quantized vectors
983/// Uses NEON vdotq_s32 on ARM64, AVX2 _mm256_maddubs_epi16 on x86_64
984#[inline(always)]
985pub fn dot_product_i8(a: &[i8], b: &[i8]) -> i32 {
986    #[cfg(target_arch = "x86_64")]
987    {
988        if is_x86_feature_detected!("avx2") {
989            unsafe { dot_product_i8_avx2_impl(a, b) }
990        } else {
991            dot_product_i8_scalar(a, b)
992        }
993    }
994
995    #[cfg(target_arch = "aarch64")]
996    {
997        unsafe { dot_product_i8_neon_impl(a, b) }
998    }
999
1000    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
1001    {
1002        dot_product_i8_scalar(a, b)
1003    }
1004}
1005
1006/// SIMD-accelerated euclidean distance squared for INT8 quantized vectors
1007/// Returns squared distance (caller should sqrt if needed)
1008#[inline(always)]
1009pub fn euclidean_distance_squared_i8(a: &[i8], b: &[i8]) -> i32 {
1010    #[cfg(target_arch = "x86_64")]
1011    {
1012        if is_x86_feature_detected!("avx2") {
1013            unsafe { euclidean_distance_squared_i8_avx2_impl(a, b) }
1014        } else {
1015            euclidean_distance_squared_i8_scalar(a, b)
1016        }
1017    }
1018
1019    #[cfg(target_arch = "aarch64")]
1020    {
1021        unsafe { euclidean_distance_squared_i8_neon_impl(a, b) }
1022    }
1023
1024    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
1025    {
1026        euclidean_distance_squared_i8_scalar(a, b)
1027    }
1028}
1029
/// NEON INT8 dot product using stable intrinsics.
/// Note: uses sign extension plus widening multiply-add instead of
/// `vdotq_s32` so only stable intrinsics are required.
///
/// Processes 8 lanes per iteration: each i8 lane is widened to i16, the
/// widening multiply `vmull_s16` produces i32 products for the low and high
/// halves, and a 4-lane i32 accumulator collects both. A scalar loop
/// finishes the tail.
///
/// # Safety
/// Caller must ensure `a.len() == b.len()`. The check below is debug-only,
/// and both the vector loads and the tail loop's `get_unchecked` accesses
/// rely on it — a mismatch in release builds is undefined behavior.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn dot_product_i8_neon_impl(a: &[i8], b: &[i8]) -> i32 {
    debug_assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    // 4-lane i32 running accumulator, zero-initialized.
    let mut sum = vdupq_n_s32(0);

    // Process 8 int8s at a time (extend to i16, multiply, accumulate)
    let chunks = len / 8;
    let mut idx = 0usize;

    for _ in 0..chunks {
        // SAFETY: idx + 8 <= chunks * 8 <= len, so both 8-byte loads are in bounds.
        let va = vld1_s8(a_ptr.add(idx));
        let vb = vld1_s8(b_ptr.add(idx));

        // Sign-extend to i16
        let va_i16 = vmovl_s8(va);
        let vb_i16 = vmovl_s8(vb);

        // Widening multiply i16 * i16 -> i32, low and high halves separately.
        let prod_lo = vmull_s16(vget_low_s16(va_i16), vget_low_s16(vb_i16));
        let prod_hi = vmull_s16(vget_high_s16(va_i16), vget_high_s16(vb_i16));

        // Accumulate
        sum = vaddq_s32(sum, prod_lo);
        sum = vaddq_s32(sum, prod_hi);

        idx += 8;
    }

    // Horizontal sum: fold the 4 i32 lanes into one scalar.
    let mut total = vaddvq_s32(sum);

    // Handle remaining elements with bounds-check elimination
    // SAFETY: i < len, and both slices have length `len` per the contract above.
    for i in (chunks * 8)..len {
        total += (*a.get_unchecked(i) as i32) * (*b.get_unchecked(i) as i32);
    }

    total
}
1079
/// NEON INT8 squared euclidean distance using stable intrinsics.
///
/// Processes 8 lanes per iteration: widens both operands to i16, subtracts
/// (the i16 difference of two i8 values cannot overflow), squares via the
/// widening multiply `vmull_s16` into i32 lanes, and accumulates. A scalar
/// loop finishes the tail.
///
/// # Safety
/// Caller must ensure `a.len() == b.len()`. The check below is debug-only,
/// and both the vector loads and the tail loop's `get_unchecked` accesses
/// rely on it — a mismatch in release builds is undefined behavior.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn euclidean_distance_squared_i8_neon_impl(a: &[i8], b: &[i8]) -> i32 {
    debug_assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    // 4-lane i32 running accumulator, zero-initialized.
    let mut sum = vdupq_n_s32(0);

    // Process 8 int8s at a time
    let chunks = len / 8;
    let mut idx = 0usize;

    for _ in 0..chunks {
        // SAFETY: idx + 8 <= chunks * 8 <= len, so both 8-byte loads are in bounds.
        let va = vld1_s8(a_ptr.add(idx));
        let vb = vld1_s8(b_ptr.add(idx));

        // Sign-extend to i16
        let va_i16 = vmovl_s8(va);
        let vb_i16 = vmovl_s8(vb);

        // Compute difference in i16 (range [-255, 255], no overflow possible)
        let diff = vsubq_s16(va_i16, vb_i16);

        // Square and accumulate: diff^2 via widening multiply into i32 lanes
        let prod_lo = vmull_s16(vget_low_s16(diff), vget_low_s16(diff));
        let prod_hi = vmull_s16(vget_high_s16(diff), vget_high_s16(diff));

        sum = vaddq_s32(sum, prod_lo);
        sum = vaddq_s32(sum, prod_hi);

        idx += 8;
    }

    // Horizontal sum: fold the 4 i32 lanes into one scalar.
    let mut total = vaddvq_s32(sum);

    // Handle remaining elements with bounds-check elimination
    // SAFETY: i < len, and both slices have length `len` per the contract above.
    for i in (chunks * 8)..len {
        let diff = (*a.get_unchecked(i) as i32) - (*b.get_unchecked(i) as i32);
        total += diff * diff;
    }

    total
}
1130
/// AVX2 INT8 dot product.
///
/// Sign-extends 32 i8 lanes per iteration into two 16-lane i16 registers,
/// then uses `_mm256_madd_epi16` (multiply adjacent i16 pairs and add each
/// pair into an i32 lane) to accumulate into eight i32 lanes. A scalar tail
/// loop handles lengths not divisible by 32.
///
/// # Safety
/// Caller must ensure the CPU supports AVX2 (the public dispatcher checks
/// via `is_x86_feature_detected!("avx2")`). Equal lengths are enforced by
/// the assert below.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn dot_product_i8_avx2_impl(a: &[i8], b: &[i8]) -> i32 {
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let mut sum = _mm256_setzero_si256();

    // Process 32 int8s at a time
    let chunks = len / 32;
    for i in 0..chunks {
        let idx = i * 32;
        // Unaligned 256-bit loads; idx + 32 <= len keeps both in bounds.
        let va = _mm256_loadu_si256(a.as_ptr().add(idx) as *const __m256i);
        let vb = _mm256_loadu_si256(b.as_ptr().add(idx) as *const __m256i);

        // For signed int8 multiply, we need to extend to i16 first
        let va_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(va));
        let vb_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(vb));
        let va_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(va, 1));
        let vb_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vb, 1));

        // madd: multiply adjacent i16 pairs, summing each pair into an i32 lane.
        let prod_lo = _mm256_madd_epi16(va_lo, vb_lo);
        let prod_hi = _mm256_madd_epi16(va_hi, vb_hi);

        sum = _mm256_add_epi32(sum, prod_lo);
        sum = _mm256_add_epi32(sum, prod_hi);
    }

    // Horizontal sum: reinterpreting __m256i as [i32; 8] is sound
    // (same size; every bit pattern is a valid i32).
    let sum_arr: [i32; 8] = std::mem::transmute(sum);
    let mut total: i32 = sum_arr.iter().sum();

    // Handle remaining elements
    for i in (chunks * 32)..len {
        total += (a[i] as i32) * (b[i] as i32);
    }

    total
}
1171
/// AVX2 INT8 squared euclidean distance.
///
/// Sign-extends 32 i8 lanes per iteration to i16, subtracts (the i16
/// difference of two i8 values cannot overflow), and squares/accumulates via
/// `_mm256_madd_epi16` (multiply adjacent i16 pairs, add each pair into an
/// i32 lane). A scalar tail loop handles lengths not divisible by 32.
///
/// NOTE(review): each per-element square is at most 255^2, so the i32
/// accumulator could overflow for very long vectors (len on the order of
/// 33k+ with extreme values) — confirm expected dimensions stay well below.
///
/// # Safety
/// Caller must ensure the CPU supports AVX2 (the public dispatcher checks
/// via `is_x86_feature_detected!("avx2")`). Equal lengths are enforced by
/// the assert below.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn euclidean_distance_squared_i8_avx2_impl(a: &[i8], b: &[i8]) -> i32 {
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let mut sum = _mm256_setzero_si256();

    let chunks = len / 32;
    for i in 0..chunks {
        let idx = i * 32;
        // Unaligned 256-bit loads; idx + 32 <= len keeps both in bounds.
        let va = _mm256_loadu_si256(a.as_ptr().add(idx) as *const __m256i);
        let vb = _mm256_loadu_si256(b.as_ptr().add(idx) as *const __m256i);

        // Extend to i16, compute difference, then square
        let va_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(va));
        let vb_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(vb));
        let va_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(va, 1));
        let vb_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vb, 1));

        let diff_lo = _mm256_sub_epi16(va_lo, vb_lo);
        let diff_hi = _mm256_sub_epi16(va_hi, vb_hi);

        // madd(diff, diff): adjacent squared differences summed pairwise into i32.
        let sq_lo = _mm256_madd_epi16(diff_lo, diff_lo);
        let sq_hi = _mm256_madd_epi16(diff_hi, diff_hi);

        sum = _mm256_add_epi32(sum, sq_lo);
        sum = _mm256_add_epi32(sum, sq_hi);
    }

    // Horizontal sum: reinterpreting __m256i as [i32; 8] is sound
    // (same size; every bit pattern is a valid i32).
    let sum_arr: [i32; 8] = std::mem::transmute(sum);
    let mut total: i32 = sum_arr.iter().sum();

    // Scalar tail for the remaining (len % 32) elements.
    for i in (chunks * 32)..len {
        let diff = (a[i] as i32) - (b[i] as i32);
        total += diff * diff;
    }

    total
}
1213
/// Scalar fallback for the INT8 dot product: widens each element to i32
/// before multiplying so no intermediate overflow can occur.
fn dot_product_i8_scalar(a: &[i8], b: &[i8]) -> i32 {
    let mut acc = 0i32;
    for (&x, &y) in a.iter().zip(b.iter()) {
        acc += i32::from(x) * i32::from(y);
    }
    acc
}
1221
/// Scalar fallback for the INT8 squared euclidean distance: widens to i32
/// before subtracting so the worst-case difference (±255) cannot overflow.
fn euclidean_distance_squared_i8_scalar(a: &[i8], b: &[i8]) -> i32 {
    a.iter().zip(b.iter()).fold(0i32, |acc, (&x, &y)| {
        let d = i32::from(x) - i32::from(y);
        acc + d * d
    })
}
1232
1233// ============================================================================
1234// Batch Operations (Cache-optimized)
1235// ============================================================================
1236
1237/// Batch dot product - compute dot products of one query vector against multiple vectors
1238/// Returns results in the provided output slice
1239/// Optimized for cache locality by processing vectors in tiles
1240#[inline]
1241pub fn batch_dot_product(query: &[f32], vectors: &[&[f32]], results: &mut [f32]) {
1242    assert_eq!(
1243        vectors.len(),
1244        results.len(),
1245        "Output size must match vector count"
1246    );
1247
1248    // Process in tiles for better cache utilization
1249    const TILE_SIZE: usize = 16;
1250
1251    for (chunk_idx, chunk) in vectors.chunks(TILE_SIZE).enumerate() {
1252        let base_idx = chunk_idx * TILE_SIZE;
1253        for (i, vec) in chunk.iter().enumerate() {
1254            results[base_idx + i] = dot_product_simd(query, vec);
1255        }
1256    }
1257}
1258
1259/// Batch euclidean distance - compute distances from one query to multiple vectors
1260/// Returns results in the provided output slice
1261/// Optimized for cache locality
1262#[inline]
1263pub fn batch_euclidean(query: &[f32], vectors: &[&[f32]], results: &mut [f32]) {
1264    assert_eq!(
1265        vectors.len(),
1266        results.len(),
1267        "Output size must match vector count"
1268    );
1269
1270    const TILE_SIZE: usize = 16;
1271
1272    for (chunk_idx, chunk) in vectors.chunks(TILE_SIZE).enumerate() {
1273        let base_idx = chunk_idx * TILE_SIZE;
1274        for (i, vec) in chunk.iter().enumerate() {
1275            results[base_idx + i] = euclidean_distance_simd(query, vec);
1276        }
1277    }
1278}
1279
1280/// Batch cosine similarity - compute similarities from one query to multiple vectors
1281#[inline]
1282pub fn batch_cosine_similarity(query: &[f32], vectors: &[&[f32]], results: &mut [f32]) {
1283    assert_eq!(
1284        vectors.len(),
1285        results.len(),
1286        "Output size must match vector count"
1287    );
1288
1289    const TILE_SIZE: usize = 16;
1290
1291    for (chunk_idx, chunk) in vectors.chunks(TILE_SIZE).enumerate() {
1292        let base_idx = chunk_idx * TILE_SIZE;
1293        for (i, vec) in chunk.iter().enumerate() {
1294            results[base_idx + i] = cosine_similarity_simd(query, vec);
1295        }
1296    }
1297}
1298
1299/// Batch dot product with owned vectors (for convenience)
1300#[inline]
1301pub fn batch_dot_product_owned(query: &[f32], vectors: &[Vec<f32>]) -> Vec<f32> {
1302    let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();
1303    let mut results = vec![0.0; vectors.len()];
1304    batch_dot_product(query, &refs, &mut results);
1305    results
1306}
1307
1308/// Batch euclidean distance with owned vectors (for convenience)
1309#[inline]
1310pub fn batch_euclidean_owned(query: &[f32], vectors: &[Vec<f32>]) -> Vec<f32> {
1311    let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();
1312    let mut results = vec![0.0; vectors.len()];
1313    batch_euclidean(query, &refs, &mut results);
1314    results
1315}
1316
#[cfg(test)]
mod tests {
    use super::*;

    // Most tests compare the architecture-dispatched entry point against the
    // scalar reference implementation instead of a hard-coded value, so the
    // suite passes regardless of which SIMD path (AVX-512/AVX2/NEON/scalar)
    // the dispatcher selects on the test machine.

    #[test]
    fn test_euclidean_distance_simd() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let b = vec![2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];

        let result = euclidean_distance_simd(&a, &b);
        let expected = euclidean_distance_scalar(&a, &b);

        assert!(
            (result - expected).abs() < 0.001,
            "SIMD result {} differs from scalar result {}",
            result,
            expected
        );
    }

    #[test]
    fn test_euclidean_distance_large() {
        // Test with 128-dim vectors (common embedding size)
        let a: Vec<f32> = (0..128).map(|i| i as f32 * 0.1).collect();
        let b: Vec<f32> = (0..128).map(|i| (i as f32 * 0.1) + 0.5).collect();

        let result = euclidean_distance_simd(&a, &b);
        let expected = euclidean_distance_scalar(&a, &b);

        // Looser tolerance: SIMD lane-wise accumulation reorders float adds.
        assert!(
            (result - expected).abs() < 0.01,
            "Large vector: SIMD {} vs scalar {}",
            result,
            expected
        );
    }

    #[test]
    fn test_dot_product_simd() {
        // 16 pairs of 1.0 * 2.0 => 32.0 exactly.
        let a = vec![1.0; 16];
        let b = vec![2.0; 16];

        let result = dot_product_simd(&a, &b);
        assert!((result - 32.0).abs() < 0.001);
    }

    #[test]
    fn test_dot_product_large() {
        let a: Vec<f32> = (0..256).map(|i| (i % 10) as f32).collect();
        let b: Vec<f32> = (0..256).map(|i| ((i + 5) % 10) as f32).collect();

        let result = dot_product_simd(&a, &b);
        let expected = dot_product_scalar(&a, &b);

        assert!(
            (result - expected).abs() < 0.1,
            "Large dot product: SIMD {} vs scalar {}",
            result,
            expected
        );
    }

    #[test]
    fn test_cosine_similarity_simd() {
        // Identical vectors have similarity 1.
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];

        let result = cosine_similarity_simd(&a, &b);
        assert!((result - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = vec![1.0, 0.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0, 0.0];

        let result = cosine_similarity_simd(&a, &b);
        assert!(
            result.abs() < 0.001,
            "Orthogonal vectors should have ~0 similarity, got {}",
            result
        );
    }

    #[test]
    fn test_manhattan_distance_simd() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];

        let result = manhattan_distance_simd(&a, &b);
        let expected = manhattan_distance_scalar(&a, &b);

        assert!(
            (result - expected).abs() < 0.001,
            "Manhattan: SIMD {} vs scalar {}",
            result,
            expected
        );
        assert!((result - 16.0).abs() < 0.001); // |4| + |4| + |4| + |4| = 16
    }

    #[test]
    fn test_non_aligned_lengths() {
        // Test vectors not aligned to SIMD width (4 for NEON, 8 for AVX2)
        // so the scalar tail-loop path is exercised.
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]; // 7 elements
        let b = vec![2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];

        let result = euclidean_distance_simd(&a, &b);
        let expected = euclidean_distance_scalar(&a, &b);

        assert!(
            (result - expected).abs() < 0.001,
            "Non-aligned: SIMD {} vs scalar {}",
            result,
            expected
        );
    }

    // Legacy function tests (ensure backward compatibility)
    #[test]
    fn test_legacy_avx2_aliases() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];

        // These should work identically to the _simd versions
        // (smoke test only: results are discarded, we just require no panic).
        let _ = euclidean_distance_avx2(&a, &b);
        let _ = dot_product_avx2(&a, &b);
        let _ = cosine_similarity_avx2(&a, &b);
    }

    // INT8 quantized operation tests
    #[test]
    fn test_dot_product_i8() {
        let a: Vec<i8> = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
        let b: Vec<i8> = vec![2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17];

        let result = dot_product_i8(&a, &b);
        let expected = dot_product_i8_scalar(&a, &b);

        // Integer results must match exactly across SIMD and scalar paths.
        assert_eq!(
            result, expected,
            "INT8 dot product: SIMD {} vs scalar {}",
            result, expected
        );
    }

    #[test]
    fn test_dot_product_i8_large() {
        // Test with 128 elements (common for quantized embeddings)
        let a: Vec<i8> = (0..128)
            .map(|i| ((i % 256) as i8).wrapping_sub(64))
            .collect();
        let b: Vec<i8> = (0..128)
            .map(|i| (((i + 10) % 256) as i8).wrapping_sub(64))
            .collect();

        let result = dot_product_i8(&a, &b);
        let expected = dot_product_i8_scalar(&a, &b);

        assert_eq!(
            result, expected,
            "Large INT8 dot product: SIMD {} vs scalar {}",
            result, expected
        );
    }

    #[test]
    fn test_euclidean_distance_squared_i8() {
        let a: Vec<i8> = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
        let b: Vec<i8> = vec![2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17];

        let result = euclidean_distance_squared_i8(&a, &b);
        let expected = euclidean_distance_squared_i8_scalar(&a, &b);

        assert_eq!(
            result, expected,
            "INT8 euclidean^2: SIMD {} vs scalar {}",
            result, expected
        );
        // Each diff is 1, so 16 diffs squared = 16
        assert_eq!(result, 16, "Expected 16, got {}", result);
    }

    #[test]
    fn test_euclidean_distance_squared_i8_large() {
        let a: Vec<i8> = (0..128)
            .map(|i| ((i % 256) as i8).wrapping_sub(64))
            .collect();
        let b: Vec<i8> = (0..128)
            .map(|i| (((i + 5) % 256) as i8).wrapping_sub(64))
            .collect();

        let result = euclidean_distance_squared_i8(&a, &b);
        let expected = euclidean_distance_squared_i8_scalar(&a, &b);

        assert_eq!(
            result, expected,
            "Large INT8 euclidean^2: SIMD {} vs scalar {}",
            result, expected
        );
    }

    // Batch operation tests
    #[test]
    fn test_batch_dot_product() {
        // Query against the standard basis picks out individual components.
        let query = vec![1.0, 2.0, 3.0, 4.0];
        let v1 = vec![1.0, 0.0, 0.0, 0.0];
        let v2 = vec![0.0, 1.0, 0.0, 0.0];
        let v3 = vec![0.0, 0.0, 1.0, 0.0];
        let vectors: Vec<&[f32]> = vec![&v1, &v2, &v3];
        let mut results = vec![0.0; 3];

        batch_dot_product(&query, &vectors, &mut results);

        assert!((results[0] - 1.0).abs() < 0.001);
        assert!((results[1] - 2.0).abs() < 0.001);
        assert!((results[2] - 3.0).abs() < 0.001);
    }

    #[test]
    fn test_batch_euclidean() {
        // Pythagorean triples (3,4,5) and (5,12,13) give exact distances.
        let query = vec![0.0, 0.0, 0.0, 0.0];
        let v1 = vec![3.0, 4.0, 0.0, 0.0];
        let v2 = vec![0.0, 0.0, 5.0, 12.0];
        let vectors: Vec<&[f32]> = vec![&v1, &v2];
        let mut results = vec![0.0; 2];

        batch_euclidean(&query, &vectors, &mut results);

        assert!(
            (results[0] - 5.0).abs() < 0.001,
            "Expected 5.0, got {}",
            results[0]
        );
        assert!(
            (results[1] - 13.0).abs() < 0.001,
            "Expected 13.0, got {}",
            results[1]
        );
    }

    #[test]
    fn test_batch_cosine_similarity() {
        let query = vec![1.0, 0.0, 0.0, 0.0];
        let v1 = vec![1.0, 0.0, 0.0, 0.0]; // Same direction
        let v2 = vec![0.0, 1.0, 0.0, 0.0]; // Orthogonal
        let v3 = vec![-1.0, 0.0, 0.0, 0.0]; // Opposite
        let vectors: Vec<&[f32]> = vec![&v1, &v2, &v3];
        let mut results = vec![0.0; 3];

        batch_cosine_similarity(&query, &vectors, &mut results);

        assert!(
            (results[0] - 1.0).abs() < 0.001,
            "Same direction should be 1.0"
        );
        assert!(results[1].abs() < 0.001, "Orthogonal should be 0.0");
        assert!((results[2] + 1.0).abs() < 0.001, "Opposite should be -1.0");
    }

    #[test]
    fn test_batch_owned_convenience() {
        let query = vec![1.0, 2.0, 3.0, 4.0];
        let vectors = vec![vec![1.0, 0.0, 0.0, 0.0], vec![0.0, 1.0, 0.0, 0.0]];

        let results = batch_dot_product_owned(&query, &vectors);
        assert_eq!(results.len(), 2);
        assert!((results[0] - 1.0).abs() < 0.001);
        assert!((results[1] - 2.0).abs() < 0.001);
    }

    #[test]
    fn test_unrolled_vs_non_unrolled_consistency() {
        // Test that unrolled and non-unrolled implementations produce same results
        // (128 elements is above the >=64 unroll threshold used by the dispatchers).
        let a: Vec<f32> = (0..128).map(|i| i as f32 * 0.1).collect();
        let b: Vec<f32> = (0..128).map(|i| (i as f32 * 0.1) + 0.5).collect();

        let result = euclidean_distance_simd(&a, &b);
        let expected = euclidean_distance_scalar(&a, &b);

        assert!(
            (result - expected).abs() < 0.01,
            "Unrolled consistency: SIMD {} vs scalar {}",
            result,
            expected
        );
    }
}