Skip to main content

ruvector_core/
simd_intrinsics.rs

//! Custom SIMD intrinsics for performance-critical operations.
//!
//! This module provides hand-optimized SIMD implementations:
//! - AVX2/AVX-512 for x86_64 processors
//! - NEON for ARM64/Apple Silicon processors (M1/M2/M3/M4)
//!
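//! Distance calculations and other vectorized operations are dispatched automatically:
//! the target architecture is selected at compile time, and on x86_64 the best
//! instruction set the running CPU supports is chosen at runtime via CPU feature detection.
//!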
//! ## Features
//!
//! - **AVX-512 Support**: 512-bit operations processing 16 floats per iteration
//!   (x86_64 only; compiled in with the `simd-avx512` Cargo feature and used when the
//!   CPU reports `avx512f` at runtime)
//! - **INT8 Quantized Operations**: SIMD-accelerated quantized vector operations
//! - **Batch Operations**: Cache-optimized batch distance calculations
//! - **NEON Optimizations**: Prefetch hints and loop unrolling for ARM64
//!
//! ## Performance Optimizations (v2)
//!
//! - **Loop Unrolling**: unrolled loops with several independent accumulators (up to 4x)
//!   for better instruction-level parallelism
//! - **Prefetch Hints**: Software prefetching for large vectors (>256 elements)
//! - **FMA Instructions**: Fused multiply-add for improved throughput and accuracy
//! - **Efficient Horizontal Sum**: Optimized reduction operations

24#[cfg(target_arch = "x86_64")]
25use std::arch::x86_64::*;
26
27#[cfg(target_arch = "aarch64")]
28use std::arch::aarch64::*;
29
/// Software-prefetch look-ahead distance.
///
/// NOTE(review): the unit of this value is ambiguous. It was previously described as a
/// distance "in cache lines" while also noting that one 64-byte line holds 16 `f32`s.
/// Confirm whether `64` counts cache lines, bytes, or `f32` elements before using it.
/// It carries `allow(dead_code)`, which suggests that nothing currently reads it. Verify
/// this before relying on it.
#[allow(dead_code)]
const PREFETCH_DISTANCE: usize = 64;
33
34/// SIMD-optimized euclidean distance
35/// Uses AVX-512 > AVX2 on x86_64, NEON on ARM64/Apple Silicon, falls back to scalar otherwise
36///
37/// # Optimizations for M4 Pro (ARM64)
38/// - Uses 4x loop unrolling for vectors >= 64 elements
39/// - FMA instructions for improved throughput
40/// - Optimized horizontal reduction via `vaddvq_f32`
41#[inline(always)]
42pub fn euclidean_distance_simd(a: &[f32], b: &[f32]) -> f32 {
43    #[cfg(target_arch = "x86_64")]
44    {
45        #[cfg(feature = "simd-avx512")]
46        {
47            if is_x86_feature_detected!("avx512f") {
48                return unsafe { euclidean_distance_avx512_impl(a, b) };
49            }
50        }
51        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
52            unsafe { euclidean_distance_avx2_fma_impl(a, b) }
53        } else if is_x86_feature_detected!("avx2") {
54            unsafe { euclidean_distance_avx2_impl(a, b) }
55        } else {
56            euclidean_distance_scalar(a, b)
57        }
58    }
59
60    #[cfg(target_arch = "aarch64")]
61    {
62        // Use unrolled version for vectors >= 64 elements for better ILP
63        if a.len() >= 64 {
64            unsafe { euclidean_distance_neon_unrolled_impl(a, b) }
65        } else {
66            unsafe { euclidean_distance_neon_impl(a, b) }
67        }
68    }
69
70    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
71    {
72        euclidean_distance_scalar(a, b)
73    }
74}
75
76/// Legacy alias for backward compatibility
77#[inline(always)]
78pub fn euclidean_distance_avx2(a: &[f32], b: &[f32]) -> f32 {
79    euclidean_distance_simd(a, b)
80}
81
/// Plain AVX2 Euclidean distance kernel (no FMA).
///
/// Uses one 8-lane accumulator, with a separate subtract, multiply and add per
/// 8-float chunk. The lanes are then reduced, and a scalar loop handles the final
/// `len % 8` elements.
///
/// # Safety
/// The CPU must support AVX2.
///
/// # Panics
/// Panics if `a.len() != b.len()`. This is checked before any vector load.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn euclidean_distance_avx2_impl(a: &[f32], b: &[f32]) -> f32 {
    // SECURITY: equal lengths guarantee every unaligned load below stays in bounds.
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let a_chunks = a.chunks_exact(8);
    let b_chunks = b.chunks_exact(8);
    let a_tail = a_chunks.remainder();
    let b_tail = b_chunks.remainder();

    let mut acc = _mm256_setzero_ps();
    for (ca, cb) in a_chunks.zip(b_chunks) {
        // Each chunk holds exactly 8 contiguous f32s, so a 256-bit load is in bounds.
        let delta = _mm256_sub_ps(_mm256_loadu_ps(ca.as_ptr()), _mm256_loadu_ps(cb.as_ptr()));
        let squared = _mm256_mul_ps(delta, delta);
        acc = _mm256_add_ps(acc, squared);
    }

    // Spill the 8 lanes to memory and add them in lane order.
    let mut lanes = [0.0f32; 8];
    _mm256_storeu_ps(lanes.as_mut_ptr(), acc);
    let mut total: f32 = lanes.iter().sum();

    for (x, y) in a_tail.iter().zip(b_tail) {
        let delta = x - y;
        total += delta * delta;
    }

    total.sqrt()
}
122
/// AVX2 + FMA Euclidean distance kernel.
///
/// Processing happens in four stages:
/// 1. Four independent 8-lane accumulators process 32-float blocks, so consecutive fused
///    multiply-adds do not serialize on a single register.
/// 2. The accumulators are folded pairwise.
/// 3. Leftover 8-float chunks are processed.
/// 4. A scalar tail of at most 7 elements is handled.
///
/// # Safety
/// The CPU must support both AVX2 and FMA.
///
/// # Panics
/// Panics if `a.len() != b.len()`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn euclidean_distance_avx2_fma_impl(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let n = a.len();
    let pa = a.as_ptr();
    let pb = b.as_ptr();

    // Stage 1: 4 accumulators x 8 lanes = 32 floats per step.
    let mut acc = [_mm256_setzero_ps(); 4];
    let block_end = n - n % 32;
    let mut pos = 0usize;
    while pos < block_end {
        for (k, slot) in acc.iter_mut().enumerate() {
            let off = pos + 8 * k;
            let delta = _mm256_sub_ps(_mm256_loadu_ps(pa.add(off)), _mm256_loadu_ps(pb.add(off)));
            *slot = _mm256_fmadd_ps(delta, delta, *slot);
        }
        pos += 32;
    }

    // Stage 2: pairwise fold (acc0 + acc1) + (acc2 + acc3).
    let mut folded = _mm256_add_ps(_mm256_add_ps(acc[0], acc[1]), _mm256_add_ps(acc[2], acc[3]));

    // Stage 3: whole 8-float chunks left after the 32-float blocks.
    let chunk_end = n - n % 8;
    while pos < chunk_end {
        let delta = _mm256_sub_ps(_mm256_loadu_ps(pa.add(pos)), _mm256_loadu_ps(pb.add(pos)));
        folded = _mm256_fmadd_ps(delta, delta, folded);
        pos += 8;
    }

    let mut lanes = [0.0f32; 8];
    _mm256_storeu_ps(lanes.as_mut_ptr(), folded);
    let mut total: f32 = lanes.iter().sum();

    // Stage 4: scalar tail (fewer than 8 elements).
    for (x, y) in a[pos..].iter().zip(&b[pos..]) {
        let delta = x - y;
        total += delta * delta;
    }

    total.sqrt()
}
193
194// ============================================================================
195// AVX-512 implementations for x86_64 (Intel Ice Lake, Sapphire Rapids, AMD Zen 4+)
196// ============================================================================
197
/// AVX-512F Euclidean distance kernel.
///
/// Processes 16 lanes per step with fused multiply-add and reduces the lanes with
/// `_mm512_reduce_add_ps`. A scalar loop handles the final 0-15 elements.
///
/// # Safety
/// The CPU must support AVX-512F.
///
/// # Panics
/// Panics if `a.len() != b.len()`.
#[cfg(all(target_arch = "x86_64", feature = "simd-avx512"))]
#[target_feature(enable = "avx512f")]
unsafe fn euclidean_distance_avx512_impl(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let xs = a.chunks_exact(16);
    let ys = b.chunks_exact(16);
    let (x_tail, y_tail) = (xs.remainder(), ys.remainder());

    let mut acc = _mm512_setzero_ps();
    for (xa, ya) in xs.zip(ys) {
        let delta = _mm512_sub_ps(_mm512_loadu_ps(xa.as_ptr()), _mm512_loadu_ps(ya.as_ptr()));
        acc = _mm512_fmadd_ps(delta, delta, acc);
    }

    let mut total = _mm512_reduce_add_ps(acc);
    for (x, y) in x_tail.iter().zip(y_tail) {
        let delta = x - y;
        total += delta * delta;
    }

    total.sqrt()
}
228
/// AVX-512F dot-product kernel.
///
/// Fuses multiply and add over 16 lanes per step, reduces the lanes, then adds the
/// products of the final 0-15 elements.
///
/// # Safety
/// The CPU must support AVX-512F.
///
/// # Panics
/// Panics if `a.len() != b.len()`.
#[cfg(all(target_arch = "x86_64", feature = "simd-avx512"))]
#[target_feature(enable = "avx512f")]
unsafe fn dot_product_avx512_impl(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let n = a.len();
    let simd_end = n - n % 16;
    let (pa, pb) = (a.as_ptr(), b.as_ptr());

    let mut acc = _mm512_setzero_ps();
    let mut off = 0usize;
    while off < simd_end {
        acc = _mm512_fmadd_ps(_mm512_loadu_ps(pa.add(off)), _mm512_loadu_ps(pb.add(off)), acc);
        off += 16;
    }

    let mut total = _mm512_reduce_add_ps(acc);
    for (x, y) in a[simd_end..].iter().zip(&b[simd_end..]) {
        total += x * y;
    }

    total
}
254
/// AVX-512F cosine similarity kernel.
///
/// Accumulates `a·b`, `|a|²` and `|b|²` in three 16-lane registers, reduces each, and
/// finishes the final 0-15 elements in scalar code. Returns
/// `dot / (sqrt(|a|²) * sqrt(|b|²))`, so a zero-norm input yields NaN or an infinity.
///
/// # Safety
/// The CPU must support AVX-512F.
///
/// # Panics
/// Panics if `a.len() != b.len()`.
#[cfg(all(target_arch = "x86_64", feature = "simd-avx512"))]
#[target_feature(enable = "avx512f")]
unsafe fn cosine_similarity_avx512_impl(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let xs = a.chunks_exact(16);
    let ys = b.chunks_exact(16);
    let (x_tail, y_tail) = (xs.remainder(), ys.remainder());

    let mut ab = _mm512_setzero_ps();
    let mut aa = _mm512_setzero_ps();
    let mut bb = _mm512_setzero_ps();
    for (xa, ya) in xs.zip(ys) {
        let va = _mm512_loadu_ps(xa.as_ptr());
        let vb = _mm512_loadu_ps(ya.as_ptr());
        ab = _mm512_fmadd_ps(va, vb, ab);
        aa = _mm512_fmadd_ps(va, va, aa);
        bb = _mm512_fmadd_ps(vb, vb, bb);
    }

    let mut dot_sum = _mm512_reduce_add_ps(ab);
    let mut norm_a_sum = _mm512_reduce_add_ps(aa);
    let mut norm_b_sum = _mm512_reduce_add_ps(bb);

    for (&x, &y) in x_tail.iter().zip(y_tail) {
        dot_sum += x * y;
        norm_a_sum += x * x;
        norm_b_sum += y * y;
    }

    dot_sum / (norm_a_sum.sqrt() * norm_b_sum.sqrt())
}
289
/// AVX-512F Manhattan (L1) distance kernel.
///
/// Sums `|a - b|` over 16 lanes per step using `_mm512_abs_ps`, reduces the lanes, and
/// handles the final 0-15 elements in scalar code.
///
/// # Safety
/// The CPU must support AVX-512F.
///
/// # Panics
/// Panics if `a.len() != b.len()`.
#[cfg(all(target_arch = "x86_64", feature = "simd-avx512"))]
#[target_feature(enable = "avx512f")]
unsafe fn manhattan_distance_avx512_impl(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let n = a.len();
    let simd_end = n - n % 16;
    let (pa, pb) = (a.as_ptr(), b.as_ptr());

    let mut acc = _mm512_setzero_ps();
    let mut off = 0usize;
    while off < simd_end {
        let delta = _mm512_sub_ps(_mm512_loadu_ps(pa.add(off)), _mm512_loadu_ps(pb.add(off)));
        acc = _mm512_add_ps(acc, _mm512_abs_ps(delta));
        off += 16;
    }

    let mut total = _mm512_reduce_add_ps(acc);
    for (x, y) in a[simd_end..].iter().zip(&b[simd_end..]) {
        total += (x - y).abs();
    }

    total
}
317
318// ============================================================================
319// NEON implementations for ARM64/Apple Silicon (M1/M2/M3/M4)
320// ============================================================================
321
/// Single-accumulator NEON Euclidean distance kernel (4 lanes per step).
///
/// Each 4-float chunk contributes `(a - b)²` via `vfmaq_f32`. The lanes are reduced
/// with `vaddvq_f32`, and the final `len % 4` elements are handled in scalar code.
/// `euclidean_distance_simd` uses this kernel for inputs shorter than 64 elements.
///
/// # Safety
/// Caller must ensure `a.len() == b.len()`. This is only checked with `debug_assert_eq!`.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
#[allow(dead_code)]
unsafe fn euclidean_distance_neon_impl(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let xs = a.chunks_exact(4);
    let ys = b.chunks_exact(4);
    let (x_tail, y_tail) = (xs.remainder(), ys.remainder());

    let mut acc = vdupq_n_f32(0.0);
    for (xa, ya) in xs.zip(ys) {
        let delta = vsubq_f32(vld1q_f32(xa.as_ptr()), vld1q_f32(ya.as_ptr()));
        acc = vfmaq_f32(acc, delta, delta);
    }

    let mut total = vaddvq_f32(acc);
    for (x, y) in x_tail.iter().zip(y_tail) {
        let delta = x - y;
        total += delta * delta;
    }

    total.sqrt()
}
367
/// Single-accumulator NEON dot-product kernel (4 lanes per step, fused multiply-add).
///
/// `dot_product_simd` uses this kernel for inputs shorter than 64 elements.
///
/// # Safety
/// Caller must ensure `a.len() == b.len()`. This is only checked with `debug_assert_eq!`.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn dot_product_neon_impl(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let xs = a.chunks_exact(4);
    let ys = b.chunks_exact(4);
    let (x_tail, y_tail) = (xs.remainder(), ys.remainder());

    let mut acc = vdupq_n_f32(0.0);
    for (xa, ya) in xs.zip(ys) {
        acc = vfmaq_f32(acc, vld1q_f32(xa.as_ptr()), vld1q_f32(ya.as_ptr()));
    }

    let mut total = vaddvq_f32(acc);
    for (x, y) in x_tail.iter().zip(y_tail) {
        total += x * y;
    }

    total
}
405
/// Single-accumulator NEON cosine similarity kernel (4 lanes per step).
///
/// Tracks `a·b`, `|a|²` and `|b|²` in three vector registers, reduces each, and
/// finishes the final `len % 4` elements in scalar code. Returns
/// `dot / (sqrt(|a|²) * sqrt(|b|²))`.
///
/// # Safety
/// Caller must ensure `a.len() == b.len()`. This is only checked with `debug_assert_eq!`.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn cosine_similarity_neon_impl(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let xs = a.chunks_exact(4);
    let ys = b.chunks_exact(4);
    let (x_tail, y_tail) = (xs.remainder(), ys.remainder());

    let mut ab = vdupq_n_f32(0.0);
    let mut aa = vdupq_n_f32(0.0);
    let mut bb = vdupq_n_f32(0.0);
    for (xa, ya) in xs.zip(ys) {
        let va = vld1q_f32(xa.as_ptr());
        let vb = vld1q_f32(ya.as_ptr());
        ab = vfmaq_f32(ab, va, vb);
        aa = vfmaq_f32(aa, va, va);
        bb = vfmaq_f32(bb, vb, vb);
    }

    let mut dot_sum = vaddvq_f32(ab);
    let mut norm_a_sum = vaddvq_f32(aa);
    let mut norm_b_sum = vaddvq_f32(bb);

    for (&x, &y) in x_tail.iter().zip(y_tail) {
        dot_sum += x * y;
        norm_a_sum += x * x;
        norm_b_sum += y * y;
    }

    dot_sum / (norm_a_sum.sqrt() * norm_b_sum.sqrt())
}
455
/// Single-accumulator NEON Manhattan (L1) distance kernel (4 lanes per step).
///
/// `vabdq_f32` produces `|a - b|` per lane in one instruction. The lanes are summed,
/// and the final `len % 4` elements are handled in scalar code.
///
/// # Safety
/// Caller must ensure `a.len() == b.len()`. This is only checked with `debug_assert_eq!`.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn manhattan_distance_neon_impl(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let xs = a.chunks_exact(4);
    let ys = b.chunks_exact(4);
    let (x_tail, y_tail) = (xs.remainder(), ys.remainder());

    let mut acc = vdupq_n_f32(0.0);
    for (xa, ya) in xs.zip(ys) {
        let gap = vabdq_f32(vld1q_f32(xa.as_ptr()), vld1q_f32(ya.as_ptr()));
        acc = vaddq_f32(acc, gap);
    }

    let mut total = vaddvq_f32(acc);
    for (x, y) in x_tail.iter().zip(y_tail) {
        total += (x - y).abs();
    }

    total
}
494
/// NEON Euclidean distance kernel unrolled 4x (16 floats per main-loop step).
///
/// Four independent accumulators let consecutive `vfmaq_f32` operations overlap
/// instead of waiting on a single register. After the main loop:
/// - the accumulators are combined pairwise;
/// - any remaining 4-float groups are folded in;
/// - a scalar loop handles the final `len % 4` elements.
///
/// No software prefetch is issued here. `euclidean_distance_simd` uses this kernel for
/// inputs of 64 elements or more.
///
/// # Safety
/// Caller must ensure `a.len() == b.len()`. This is only checked with `debug_assert_eq!`.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn euclidean_distance_neon_unrolled_impl(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let n = a.len();
    let (pa, pb) = (a.as_ptr(), b.as_ptr());

    let mut s0 = vdupq_n_f32(0.0);
    let mut s1 = vdupq_n_f32(0.0);
    let mut s2 = vdupq_n_f32(0.0);
    let mut s3 = vdupq_n_f32(0.0);

    let unrolled_end = n - n % 16;
    let mut at = 0usize;
    while at < unrolled_end {
        let d0 = vsubq_f32(vld1q_f32(pa.add(at)), vld1q_f32(pb.add(at)));
        let d1 = vsubq_f32(vld1q_f32(pa.add(at + 4)), vld1q_f32(pb.add(at + 4)));
        let d2 = vsubq_f32(vld1q_f32(pa.add(at + 8)), vld1q_f32(pb.add(at + 8)));
        let d3 = vsubq_f32(vld1q_f32(pa.add(at + 12)), vld1q_f32(pb.add(at + 12)));
        s0 = vfmaq_f32(s0, d0, d0);
        s1 = vfmaq_f32(s1, d1, d1);
        s2 = vfmaq_f32(s2, d2, d2);
        s3 = vfmaq_f32(s3, d3, d3);
        at += 16;
    }

    // Pairwise combine: (s0 + s1) + (s2 + s3).
    let mut acc = vaddq_f32(vaddq_f32(s0, s1), vaddq_f32(s2, s3));

    // Remaining whole 4-float groups.
    let vector_end = n - n % 4;
    while at < vector_end {
        let d = vsubq_f32(vld1q_f32(pa.add(at)), vld1q_f32(pb.add(at)));
        acc = vfmaq_f32(acc, d, d);
        at += 4;
    }

    let mut total = vaddvq_f32(acc);
    while at < n {
        let d = *a.get_unchecked(at) - *b.get_unchecked(at);
        total += d * d;
        at += 1;
    }

    total.sqrt()
}
580
/// NEON dot-product kernel unrolled 4x (16 floats per main-loop step).
///
/// Uses four independent fused multiply-add accumulators, combined pairwise. Any
/// remaining 4-float groups are then folded in, and a scalar loop handles the last
/// `len % 4` elements. `dot_product_simd` uses this kernel for inputs of 64 elements
/// or more.
///
/// # Safety
/// Caller must ensure `a.len() == b.len()`. This is only checked with `debug_assert_eq!`.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn dot_product_neon_unrolled_impl(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let n = a.len();
    let (pa, pb) = (a.as_ptr(), b.as_ptr());

    let mut s0 = vdupq_n_f32(0.0);
    let mut s1 = vdupq_n_f32(0.0);
    let mut s2 = vdupq_n_f32(0.0);
    let mut s3 = vdupq_n_f32(0.0);

    let unrolled_end = n - n % 16;
    let mut at = 0usize;
    while at < unrolled_end {
        s0 = vfmaq_f32(s0, vld1q_f32(pa.add(at)), vld1q_f32(pb.add(at)));
        s1 = vfmaq_f32(s1, vld1q_f32(pa.add(at + 4)), vld1q_f32(pb.add(at + 4)));
        s2 = vfmaq_f32(s2, vld1q_f32(pa.add(at + 8)), vld1q_f32(pb.add(at + 8)));
        s3 = vfmaq_f32(s3, vld1q_f32(pa.add(at + 12)), vld1q_f32(pb.add(at + 12)));
        at += 16;
    }

    // Pairwise combine: (s0 + s1) + (s2 + s3).
    let mut acc = vaddq_f32(vaddq_f32(s0, s1), vaddq_f32(s2, s3));

    let vector_end = n - n % 4;
    while at < vector_end {
        acc = vfmaq_f32(acc, vld1q_f32(pa.add(at)), vld1q_f32(pb.add(at)));
        at += 4;
    }

    let mut total = vaddvq_f32(acc);
    while at < n {
        total += *a.get_unchecked(at) * *b.get_unchecked(at);
        at += 1;
    }

    total
}
649
/// NEON cosine similarity kernel unrolled 2x (8 floats per main-loop step).
///
/// Keeps a low/high pair of accumulators for each of `a·b`, `|a|²` and `|b|²`. Each
/// pair is combined before the horizontal reduction, and the final `len % 8` elements
/// are handled in scalar code. Returns `dot / (sqrt(|a|²) * sqrt(|b|²))`.
/// `cosine_similarity_simd` uses this kernel for inputs of 64 elements or more.
///
/// # Safety
/// Caller must ensure `a.len() == b.len()`. This is only checked with `debug_assert_eq!`.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn cosine_similarity_neon_unrolled_impl(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let n = a.len();
    let (pa, pb) = (a.as_ptr(), b.as_ptr());
    let paired_end = n - n % 8;

    let (mut ab_lo, mut ab_hi) = (vdupq_n_f32(0.0), vdupq_n_f32(0.0));
    let (mut aa_lo, mut aa_hi) = (vdupq_n_f32(0.0), vdupq_n_f32(0.0));
    let (mut bb_lo, mut bb_hi) = (vdupq_n_f32(0.0), vdupq_n_f32(0.0));

    let mut at = 0usize;
    while at < paired_end {
        let a_lo = vld1q_f32(pa.add(at));
        let b_lo = vld1q_f32(pb.add(at));
        let a_hi = vld1q_f32(pa.add(at + 4));
        let b_hi = vld1q_f32(pb.add(at + 4));

        ab_lo = vfmaq_f32(ab_lo, a_lo, b_lo);
        aa_lo = vfmaq_f32(aa_lo, a_lo, a_lo);
        bb_lo = vfmaq_f32(bb_lo, b_lo, b_lo);

        ab_hi = vfmaq_f32(ab_hi, a_hi, b_hi);
        aa_hi = vfmaq_f32(aa_hi, a_hi, a_hi);
        bb_hi = vfmaq_f32(bb_hi, b_hi, b_hi);

        at += 8;
    }

    let mut dot_sum = vaddvq_f32(vaddq_f32(ab_lo, ab_hi));
    let mut norm_a_sum = vaddvq_f32(vaddq_f32(aa_lo, aa_hi));
    let mut norm_b_sum = vaddvq_f32(vaddq_f32(bb_lo, bb_hi));

    while at < n {
        let x = *a.get_unchecked(at);
        let y = *b.get_unchecked(at);
        dot_sum += x * y;
        norm_a_sum += x * x;
        norm_b_sum += y * y;
        at += 1;
    }

    dot_sum / (norm_a_sum.sqrt() * norm_b_sum.sqrt())
}
709
/// NEON Manhattan (L1) distance kernel unrolled 4x (16 floats per main-loop step).
///
/// `vabdq_f32` gives `|a - b|` per lane. Four independent accumulators are combined
/// pairwise, remaining 4-float groups are folded in, and a scalar loop handles the last
/// `len % 4` elements. `manhattan_distance_simd` uses this kernel for inputs of 64
/// elements or more.
///
/// # Safety
/// Caller must ensure `a.len() == b.len()`. This is only checked with `debug_assert_eq!`.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn manhattan_distance_neon_unrolled_impl(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let n = a.len();
    let (pa, pb) = (a.as_ptr(), b.as_ptr());

    let mut s0 = vdupq_n_f32(0.0);
    let mut s1 = vdupq_n_f32(0.0);
    let mut s2 = vdupq_n_f32(0.0);
    let mut s3 = vdupq_n_f32(0.0);

    let unrolled_end = n - n % 16;
    let mut at = 0usize;
    while at < unrolled_end {
        let g0 = vabdq_f32(vld1q_f32(pa.add(at)), vld1q_f32(pb.add(at)));
        let g1 = vabdq_f32(vld1q_f32(pa.add(at + 4)), vld1q_f32(pb.add(at + 4)));
        let g2 = vabdq_f32(vld1q_f32(pa.add(at + 8)), vld1q_f32(pb.add(at + 8)));
        let g3 = vabdq_f32(vld1q_f32(pa.add(at + 12)), vld1q_f32(pb.add(at + 12)));
        s0 = vaddq_f32(s0, g0);
        s1 = vaddq_f32(s1, g1);
        s2 = vaddq_f32(s2, g2);
        s3 = vaddq_f32(s3, g3);
        at += 16;
    }

    // Pairwise combine: (s0 + s1) + (s2 + s3).
    let mut acc = vaddq_f32(vaddq_f32(s0, s1), vaddq_f32(s2, s3));

    let vector_end = n - n % 4;
    while at < vector_end {
        acc = vaddq_f32(acc, vabdq_f32(vld1q_f32(pa.add(at)), vld1q_f32(pb.add(at))));
        at += 4;
    }

    let mut total = vaddvq_f32(acc);
    while at < n {
        total += (*a.get_unchecked(at) - *b.get_unchecked(at)).abs();
        at += 1;
    }

    total
}
779
780// ============================================================================
781// Public API with architecture dispatch
782// ============================================================================
783
784/// SIMD-optimized dot product
785/// Uses AVX-512 > AVX2 on x86_64, NEON on ARM64/Apple Silicon
786#[inline(always)]
787pub fn dot_product_simd(a: &[f32], b: &[f32]) -> f32 {
788    #[cfg(target_arch = "x86_64")]
789    {
790        #[cfg(feature = "simd-avx512")]
791        {
792            if is_x86_feature_detected!("avx512f") {
793                return unsafe { dot_product_avx512_impl(a, b) };
794            }
795        }
796        if is_x86_feature_detected!("avx2") {
797            unsafe { dot_product_avx2_impl(a, b) }
798        } else {
799            dot_product_scalar(a, b)
800        }
801    }
802
803    #[cfg(target_arch = "aarch64")]
804    {
805        if a.len() >= 64 {
806            unsafe { dot_product_neon_unrolled_impl(a, b) }
807        } else {
808            unsafe { dot_product_neon_impl(a, b) }
809        }
810    }
811
812    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
813    {
814        dot_product_scalar(a, b)
815    }
816}
817
818/// Legacy alias for backward compatibility
819#[inline(always)]
820pub fn dot_product_avx2(a: &[f32], b: &[f32]) -> f32 {
821    dot_product_simd(a, b)
822}
823
/// AVX2 dot-product kernel.
///
/// Multiplies and accumulates 8 lanes per step with separate `mul`/`add` (no FMA), then
/// reduces the lanes and adds the products of the final `len % 8` elements.
///
/// # Safety
/// The CPU must support AVX2.
///
/// # Panics
/// Panics if `a.len() != b.len()`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn dot_product_avx2_impl(a: &[f32], b: &[f32]) -> f32 {
    // SECURITY: equal lengths keep every load below within both slices.
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let lhs = a.chunks_exact(8);
    let rhs = b.chunks_exact(8);
    let lhs_tail = lhs.remainder();
    let rhs_tail = rhs.remainder();

    let mut acc = _mm256_setzero_ps();
    for (xa, xb) in lhs.zip(rhs) {
        let product = _mm256_mul_ps(_mm256_loadu_ps(xa.as_ptr()), _mm256_loadu_ps(xb.as_ptr()));
        acc = _mm256_add_ps(acc, product);
    }

    let mut lanes = [0.0f32; 8];
    _mm256_storeu_ps(lanes.as_mut_ptr(), acc);
    let mut total: f32 = lanes.iter().sum();

    for (x, y) in lhs_tail.iter().zip(rhs_tail) {
        total += x * y;
    }

    total
}
851
852/// SIMD-optimized cosine similarity
853/// Uses AVX-512 > AVX2 on x86_64, NEON on ARM64/Apple Silicon
854#[inline(always)]
855pub fn cosine_similarity_simd(a: &[f32], b: &[f32]) -> f32 {
856    #[cfg(target_arch = "x86_64")]
857    {
858        #[cfg(feature = "simd-avx512")]
859        {
860            if is_x86_feature_detected!("avx512f") {
861                return unsafe { cosine_similarity_avx512_impl(a, b) };
862            }
863        }
864        if is_x86_feature_detected!("avx2") {
865            unsafe { cosine_similarity_avx2_impl(a, b) }
866        } else {
867            cosine_similarity_scalar(a, b)
868        }
869    }
870
871    #[cfg(target_arch = "aarch64")]
872    {
873        if a.len() >= 64 {
874            unsafe { cosine_similarity_neon_unrolled_impl(a, b) }
875        } else {
876            unsafe { cosine_similarity_neon_impl(a, b) }
877        }
878    }
879
880    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
881    {
882        cosine_similarity_scalar(a, b)
883    }
884}
885
886/// Legacy alias for backward compatibility
887#[inline(always)]
888pub fn cosine_similarity_avx2(a: &[f32], b: &[f32]) -> f32 {
889    cosine_similarity_simd(a, b)
890}
891
892/// SIMD-optimized Manhattan distance
893/// Uses AVX-512 > AVX2 on x86_64, NEON on ARM64/Apple Silicon, scalar on other platforms
894#[inline(always)]
895pub fn manhattan_distance_simd(a: &[f32], b: &[f32]) -> f32 {
896    #[cfg(target_arch = "x86_64")]
897    {
898        #[cfg(feature = "simd-avx512")]
899        {
900            if is_x86_feature_detected!("avx512f") {
901                return unsafe { manhattan_distance_avx512_impl(a, b) };
902            }
903        }
904        if is_x86_feature_detected!("avx2") {
905            unsafe { manhattan_distance_avx2_impl(a, b) }
906        } else {
907            manhattan_distance_scalar(a, b)
908        }
909    }
910
911    #[cfg(target_arch = "aarch64")]
912    {
913        if a.len() >= 64 {
914            unsafe { manhattan_distance_neon_unrolled_impl(a, b) }
915        } else {
916            unsafe { manhattan_distance_neon_impl(a, b) }
917        }
918    }
919
920    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
921    {
922        manhattan_distance_scalar(a, b)
923    }
924}
925
/// AVX2 cosine similarity kernel.
///
/// Accumulates `a·b`, `|a|²` and `|b|²` in three 8-lane registers using separate
/// `mul`/`add` (no FMA), then reduces each. The final `len % 8` elements are handled
/// in scalar code. Returns `dot / (sqrt(|a|²) * sqrt(|b|²))`, so a zero-norm input
/// yields NaN or an infinity.
///
/// # Safety
/// The CPU must support AVX2.
///
/// # Panics
/// Panics if `a.len() != b.len()`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn cosine_similarity_avx2_impl(a: &[f32], b: &[f32]) -> f32 {
    // SECURITY: equal lengths keep every load below within both slices.
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let n = a.len();
    let simd_end = n - n % 8;
    let (pa, pb) = (a.as_ptr(), b.as_ptr());

    let mut ab = _mm256_setzero_ps();
    let mut aa = _mm256_setzero_ps();
    let mut bb = _mm256_setzero_ps();

    let mut i = 0usize;
    while i < simd_end {
        let va = _mm256_loadu_ps(pa.add(i));
        let vb = _mm256_loadu_ps(pb.add(i));
        ab = _mm256_add_ps(ab, _mm256_mul_ps(va, vb));
        aa = _mm256_add_ps(aa, _mm256_mul_ps(va, va));
        bb = _mm256_add_ps(bb, _mm256_mul_ps(vb, vb));
        i += 8;
    }

    // Reuse one stack buffer to spill each accumulator and sum its lanes in order.
    let mut lanes = [0.0f32; 8];
    _mm256_storeu_ps(lanes.as_mut_ptr(), ab);
    let mut dot_sum: f32 = lanes.iter().sum();
    _mm256_storeu_ps(lanes.as_mut_ptr(), aa);
    let mut norm_a_sum: f32 = lanes.iter().sum();
    _mm256_storeu_ps(lanes.as_mut_ptr(), bb);
    let mut norm_b_sum: f32 = lanes.iter().sum();

    for (&x, &y) in a[simd_end..].iter().zip(&b[simd_end..]) {
        dot_sum += x * y;
        norm_a_sum += x * x;
        norm_b_sum += y * y;
    }

    dot_sum / (norm_a_sum.sqrt() * norm_b_sum.sqrt())
}
967
/// AVX2 Manhattan (L1) distance: the sum of `|a[i] - b[i]|`, eight lanes per step.
///
/// The main loop consumes 16 floats per iteration into two independent accumulators.
/// At most one further 8-float block follows, then a scalar tail.
///
/// # Safety
/// The CPU must support AVX2.
///
/// # Panics
/// Panics if `a` and `b` have different lengths.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn manhattan_distance_avx2_impl(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let n = a.len();
    let pa = a.as_ptr();
    let pb = b.as_ptr();
    // andnot(-0.0, x) keeps every bit of x except the sign bit, i.e. lane-wise |x|.
    let neg_zero = _mm256_set1_ps(-0.0);

    // Two accumulators so consecutive adds don't serialize on one register.
    let mut acc_lo = _mm256_setzero_ps();
    let mut acc_hi = _mm256_setzero_ps();
    let wide_end = n - n % 16;
    let mut off = 0;
    while off < wide_end {
        let d_lo = _mm256_sub_ps(_mm256_loadu_ps(pa.add(off)), _mm256_loadu_ps(pb.add(off)));
        acc_lo = _mm256_add_ps(acc_lo, _mm256_andnot_ps(neg_zero, d_lo));

        let d_hi = _mm256_sub_ps(
            _mm256_loadu_ps(pa.add(off + 8)),
            _mm256_loadu_ps(pb.add(off + 8)),
        );
        acc_hi = _mm256_add_ps(acc_hi, _mm256_andnot_ps(neg_zero, d_hi));

        off += 16;
    }
    let mut acc = _mm256_add_ps(acc_lo, acc_hi);

    // Fewer than 16 elements remain, so at most one 8-wide block fits.
    if n - off >= 8 {
        let d = _mm256_sub_ps(_mm256_loadu_ps(pa.add(off)), _mm256_loadu_ps(pb.add(off)));
        acc = _mm256_add_ps(acc, _mm256_andnot_ps(neg_zero, d));
        off += 8;
    }

    // Horizontal reduction in lane order.
    let mut lanes = [0.0f32; 8];
    _mm256_storeu_ps(lanes.as_mut_ptr(), acc);
    let mut total: f32 = lanes.iter().sum();

    for (x, y) in a[off..].iter().zip(&b[off..]) {
        total += (x - y).abs();
    }

    total
}
1024
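// Scalar fallback implementations.
// Used where no SIMD path applies (e.g. `manhattan_distance_simd` on x86_64 CPUs
// without AVX2, or on targets other than x86_64/aarch64) and as reference results in the tests.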
1027
/// Reference (non-SIMD) Euclidean distance: `sqrt(sum((a[i] - b[i])^2))`.
///
/// Elements are paired with `zip`, so if the slices differ in length only the
/// common prefix contributes.
// May be unreferenced on targets whose dispatcher never falls back to scalar.
#[allow(dead_code)]
fn euclidean_distance_scalar(a: &[f32], b: &[f32]) -> f32 {
    let squared_sum: f32 = a
        .iter()
        .zip(b)
        .map(|(&x, &y)| (x - y) * (x - y))
        .sum();
    squared_sum.sqrt()
}
1039
/// Reference (non-SIMD) dot product: `sum(a[i] * b[i])` over the common prefix.
// May be unreferenced on targets whose dispatcher never falls back to scalar.
#[allow(dead_code)]
fn dot_product_scalar(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b.iter()).map(|(lhs, rhs)| lhs * rhs);
    products.sum()
}
1044
/// Reference (non-SIMD) cosine similarity: `dot(a, b) / (|a| * |b|)`.
///
/// The dot product uses only the common prefix, while each norm covers its whole
/// slice. A zero-norm input is not special-cased and yields NaN.
// May be unreferenced on targets whose dispatcher never falls back to scalar.
#[allow(dead_code)]
fn cosine_similarity_scalar(a: &[f32], b: &[f32]) -> f32 {
    let l2 = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    dot / (l2(a) * l2(b))
}
1052
/// Reference (non-SIMD) Manhattan distance: `sum(|a[i] - b[i]|)` over the common prefix.
// Also the runtime fallback of `manhattan_distance_simd` on x86_64 without AVX2.
#[allow(dead_code)]
fn manhattan_distance_scalar(a: &[f32], b: &[f32]) -> f32 {
    let abs_diff = |(x, y): (&f32, &f32)| (x - y).abs();
    a.iter().zip(b.iter()).map(abs_diff).sum::<f32>()
}
1057
1058// ============================================================================
1059// INT8 Quantized Operations
1060// ============================================================================
1061
1062/// SIMD-accelerated dot product for INT8 quantized vectors
1063/// Uses NEON vdotq_s32 on ARM64, AVX2 _mm256_maddubs_epi16 on x86_64
1064#[inline(always)]
1065pub fn dot_product_i8(a: &[i8], b: &[i8]) -> i32 {
1066    #[cfg(target_arch = "x86_64")]
1067    {
1068        if is_x86_feature_detected!("avx2") {
1069            unsafe { dot_product_i8_avx2_impl(a, b) }
1070        } else {
1071            dot_product_i8_scalar(a, b)
1072        }
1073    }
1074
1075    #[cfg(target_arch = "aarch64")]
1076    {
1077        unsafe { dot_product_i8_neon_impl(a, b) }
1078    }
1079
1080    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
1081    {
1082        dot_product_i8_scalar(a, b)
1083    }
1084}
1085
1086/// SIMD-accelerated euclidean distance squared for INT8 quantized vectors
1087/// Returns squared distance (caller should sqrt if needed)
1088#[inline(always)]
1089pub fn euclidean_distance_squared_i8(a: &[i8], b: &[i8]) -> i32 {
1090    #[cfg(target_arch = "x86_64")]
1091    {
1092        if is_x86_feature_detected!("avx2") {
1093            unsafe { euclidean_distance_squared_i8_avx2_impl(a, b) }
1094        } else {
1095            euclidean_distance_squared_i8_scalar(a, b)
1096        }
1097    }
1098
1099    #[cfg(target_arch = "aarch64")]
1100    {
1101        unsafe { euclidean_distance_squared_i8_neon_impl(a, b) }
1102    }
1103
1104    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
1105    {
1106        euclidean_distance_squared_i8_scalar(a, b)
1107    }
1108}
1109
/// NEON INT8 dot product using stable intrinsics.
///
/// `vdotq_s32` is not used; it is presumably unavailable as a stable Rust intrinsic
/// (TODO confirm). Instead, each step does the following:
/// 1. Load 8 bytes per input with `vld1_s8`.
/// 2. Sign-extend them to i16 with `vmovl_s8`.
/// 3. Multiply-accumulate the low and high halves into four i32 lanes with `vmlal_s16`.
///
/// Fewer than 8 leftover elements are handled with scalar code.
///
/// # Safety
/// Lengths are verified with `assert_eq!` before any raw load, so callers have no
/// extra obligations. The function stays `unsafe` to match the other SIMD kernels.
///
/// # Panics
/// Panics if `a.len() != b.len()`. The check must run in release builds too: the loop
/// loads 8 bytes from both slices at offsets derived from `a.len()` alone.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn dot_product_i8_neon_impl(a: &[i8], b: &[i8]) -> i32 {
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let vec_end = len - len % 8;
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    let mut sum = vdupq_n_s32(0);
    let mut idx = 0usize;
    while idx < vec_end {
        // SAFETY: idx + 8 <= vec_end <= len == b.len(), so both 8-byte loads are in bounds.
        let va_i16 = vmovl_s8(vld1_s8(a_ptr.add(idx)));
        let vb_i16 = vmovl_s8(vld1_s8(b_ptr.add(idx)));

        // sum += widen(lo(a) * lo(b)); sum += widen(hi(a) * hi(b))
        sum = vmlal_s16(sum, vget_low_s16(va_i16), vget_low_s16(vb_i16));
        sum = vmlal_s16(sum, vget_high_s16(va_i16), vget_high_s16(vb_i16));

        idx += 8;
    }

    // Horizontal sum of the four i32 lanes.
    let mut total = vaddvq_s32(sum);

    // Scalar tail; plain slicing is safe because the lengths were asserted equal.
    for (&x, &y) in a[vec_end..].iter().zip(&b[vec_end..]) {
        total += (x as i32) * (y as i32);
    }

    total
}
1159
/// NEON INT8 euclidean distance squared using stable intrinsics.
///
/// Each step does the following:
/// 1. Load 8 bytes per input.
/// 2. Sign-extend them to i16 and subtract. An i8 difference lies in [-255, 255], so
///    it cannot wrap in i16.
/// 3. Multiply-accumulate the squared halves into four i32 lanes with `vmlal_s16`.
///
/// Fewer than 8 leftover elements are handled with scalar code.
///
/// # Safety
/// Lengths are verified with `assert_eq!` before any raw load, so callers have no
/// extra obligations. The function stays `unsafe` to match the other SIMD kernels.
///
/// # Panics
/// Panics if `a.len() != b.len()`. The check must run in release builds too: the loop
/// loads 8 bytes from both slices at offsets derived from `a.len()` alone.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn euclidean_distance_squared_i8_neon_impl(a: &[i8], b: &[i8]) -> i32 {
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let vec_end = len - len % 8;
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    let mut sum = vdupq_n_s32(0);
    let mut idx = 0usize;
    while idx < vec_end {
        // SAFETY: idx + 8 <= vec_end <= len == b.len(), so both 8-byte loads are in bounds.
        let va_i16 = vmovl_s8(vld1_s8(a_ptr.add(idx)));
        let vb_i16 = vmovl_s8(vld1_s8(b_ptr.add(idx)));

        let diff = vsubq_s16(va_i16, vb_i16);

        // sum += widen(diff_lo^2); sum += widen(diff_hi^2)
        sum = vmlal_s16(sum, vget_low_s16(diff), vget_low_s16(diff));
        sum = vmlal_s16(sum, vget_high_s16(diff), vget_high_s16(diff));

        idx += 8;
    }

    let mut total = vaddvq_s32(sum);

    // Scalar tail; plain slicing is safe because the lengths were asserted equal.
    for (&x, &y) in a[vec_end..].iter().zip(&b[vec_end..]) {
        let diff = (x as i32) - (y as i32);
        total += diff * diff;
    }

    total
}
1210
/// AVX2 INT8 dot product: 32 bytes per step, widened to i16 before multiplying.
///
/// Each 32-byte block is split into two 16-byte halves and sign-extended to sixteen
/// i16 lanes. `_mm256_madd_epi16` then multiplies lane pairs and adds adjacent
/// products into i32 lanes. Fewer than 32 leftover elements use scalar code.
///
/// # Safety
/// The CPU must support AVX2.
///
/// # Panics
/// Panics if `a` and `b` have different lengths.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn dot_product_i8_avx2_impl(a: &[i8], b: &[i8]) -> i32 {
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let n = a.len();
    let vec_end = n - n % 32;
    let mut acc = _mm256_setzero_si256();

    let mut off = 0;
    while off < vec_end {
        let xa = _mm256_loadu_si256(a.as_ptr().add(off) as *const __m256i);
        let xb = _mm256_loadu_si256(b.as_ptr().add(off) as *const __m256i);

        // Bytes 0..16 and 16..32 of each operand, sign-extended to i16.
        let a_first = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(xa));
        let b_first = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(xb));
        let a_second = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(xa, 1));
        let b_second = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(xb, 1));

        acc = _mm256_add_epi32(acc, _mm256_madd_epi16(a_first, b_first));
        acc = _mm256_add_epi32(acc, _mm256_madd_epi16(a_second, b_second));

        off += 32;
    }

    // Horizontal reduction of the eight i32 lanes.
    let mut lanes = [0i32; 8];
    _mm256_storeu_si256(lanes.as_mut_ptr() as *mut __m256i, acc);
    let mut total: i32 = lanes.iter().sum();

    for (&x, &y) in a[vec_end..].iter().zip(&b[vec_end..]) {
        total += (x as i32) * (y as i32);
    }

    total
}
1251
/// AVX2 INT8 squared Euclidean distance: 32 bytes per step.
///
/// Bytes are sign-extended to i16 before subtracting, so each difference lies in
/// [-255, 255] without wrapping. `_mm256_madd_epi16(d, d)` then squares the
/// differences and adds adjacent pairs into i32 lanes. Fewer than 32 leftover elements
/// use scalar code.
///
/// # Safety
/// The CPU must support AVX2.
///
/// # Panics
/// Panics if `a` and `b` have different lengths.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn euclidean_distance_squared_i8_avx2_impl(a: &[i8], b: &[i8]) -> i32 {
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let n = a.len();
    let vec_end = n - n % 32;
    let mut acc = _mm256_setzero_si256();

    let mut off = 0;
    while off < vec_end {
        let xa = _mm256_loadu_si256(a.as_ptr().add(off) as *const __m256i);
        let xb = _mm256_loadu_si256(b.as_ptr().add(off) as *const __m256i);

        let d_first = _mm256_sub_epi16(
            _mm256_cvtepi8_epi16(_mm256_castsi256_si128(xa)),
            _mm256_cvtepi8_epi16(_mm256_castsi256_si128(xb)),
        );
        let d_second = _mm256_sub_epi16(
            _mm256_cvtepi8_epi16(_mm256_extracti128_si256(xa, 1)),
            _mm256_cvtepi8_epi16(_mm256_extracti128_si256(xb, 1)),
        );

        acc = _mm256_add_epi32(acc, _mm256_madd_epi16(d_first, d_first));
        acc = _mm256_add_epi32(acc, _mm256_madd_epi16(d_second, d_second));

        off += 32;
    }

    let mut lanes = [0i32; 8];
    _mm256_storeu_si256(lanes.as_mut_ptr() as *mut __m256i, acc);
    let mut total: i32 = lanes.iter().sum();

    for (&x, &y) in a[vec_end..].iter().zip(&b[vec_end..]) {
        let d = (x as i32) - (y as i32);
        total += d * d;
    }

    total
}
1293
/// Scalar fallback for INT8 dot product.
///
/// Each pair is widened to `i32` before multiplying. Only the common prefix of the
/// two slices contributes.
// Unreferenced on aarch64 outside tests, hence the allow.
#[allow(dead_code)]
fn dot_product_i8_scalar(a: &[i8], b: &[i8]) -> i32 {
    let widened_products = a
        .iter()
        .zip(b)
        .map(|(&lhs, &rhs)| i32::from(lhs) * i32::from(rhs));
    widened_products.sum()
}
1302
/// Scalar fallback for INT8 euclidean distance squared.
///
/// Differences are taken in `i32`, so extremes like `127 - (-128)` do not wrap.
/// Only the common prefix of the two slices contributes.
// Unreferenced on aarch64 outside tests, hence the allow.
#[allow(dead_code)]
fn euclidean_distance_squared_i8_scalar(a: &[i8], b: &[i8]) -> i32 {
    let squared_gap = |(&lhs, &rhs): (&i8, &i8)| {
        let gap = i32::from(lhs) - i32::from(rhs);
        gap * gap
    };
    a.iter().zip(b).map(squared_gap).sum()
}
1314
1315// ============================================================================
1316// Batch Operations (Cache-optimized)
1317// ============================================================================
1318
1319/// Batch dot product - compute dot products of one query vector against multiple vectors
1320/// Returns results in the provided output slice
1321/// Optimized for cache locality by processing vectors in tiles
1322#[inline]
1323pub fn batch_dot_product(query: &[f32], vectors: &[&[f32]], results: &mut [f32]) {
1324    assert_eq!(
1325        vectors.len(),
1326        results.len(),
1327        "Output size must match vector count"
1328    );
1329
1330    // Process in tiles for better cache utilization
1331    const TILE_SIZE: usize = 16;
1332
1333    for (chunk_idx, chunk) in vectors.chunks(TILE_SIZE).enumerate() {
1334        let base_idx = chunk_idx * TILE_SIZE;
1335        for (i, vec) in chunk.iter().enumerate() {
1336            results[base_idx + i] = dot_product_simd(query, vec);
1337        }
1338    }
1339}
1340
1341/// Batch euclidean distance - compute distances from one query to multiple vectors
1342/// Returns results in the provided output slice
1343/// Optimized for cache locality
1344#[inline]
1345pub fn batch_euclidean(query: &[f32], vectors: &[&[f32]], results: &mut [f32]) {
1346    assert_eq!(
1347        vectors.len(),
1348        results.len(),
1349        "Output size must match vector count"
1350    );
1351
1352    const TILE_SIZE: usize = 16;
1353
1354    for (chunk_idx, chunk) in vectors.chunks(TILE_SIZE).enumerate() {
1355        let base_idx = chunk_idx * TILE_SIZE;
1356        for (i, vec) in chunk.iter().enumerate() {
1357            results[base_idx + i] = euclidean_distance_simd(query, vec);
1358        }
1359    }
1360}
1361
1362/// Batch cosine similarity - compute similarities from one query to multiple vectors
1363#[inline]
1364pub fn batch_cosine_similarity(query: &[f32], vectors: &[&[f32]], results: &mut [f32]) {
1365    assert_eq!(
1366        vectors.len(),
1367        results.len(),
1368        "Output size must match vector count"
1369    );
1370
1371    const TILE_SIZE: usize = 16;
1372
1373    for (chunk_idx, chunk) in vectors.chunks(TILE_SIZE).enumerate() {
1374        let base_idx = chunk_idx * TILE_SIZE;
1375        for (i, vec) in chunk.iter().enumerate() {
1376            results[base_idx + i] = cosine_similarity_simd(query, vec);
1377        }
1378    }
1379}
1380
1381/// Batch dot product with owned vectors (for convenience)
1382#[inline]
1383pub fn batch_dot_product_owned(query: &[f32], vectors: &[Vec<f32>]) -> Vec<f32> {
1384    let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();
1385    let mut results = vec![0.0; vectors.len()];
1386    batch_dot_product(query, &refs, &mut results);
1387    results
1388}
1389
1390/// Batch euclidean distance with owned vectors (for convenience)
1391#[inline]
1392pub fn batch_euclidean_owned(query: &[f32], vectors: &[Vec<f32>]) -> Vec<f32> {
1393    let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();
1394    let mut results = vec![0.0; vectors.len()];
1395    batch_euclidean(query, &refs, &mut results);
1396    results
1397}
1398
#[cfg(test)]
mod tests {
    // Every SIMD entry point is compared against its scalar reference. Input lengths
    // are chosen so each kernel's vector loop, remainder blocks, and scalar tail run.
    use super::*;

    #[test]
    fn test_euclidean_distance_simd() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let b = vec![2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];

        let result = euclidean_distance_simd(&a, &b);
        let expected = euclidean_distance_scalar(&a, &b);

        assert!(
            (result - expected).abs() < 0.001,
            "SIMD result {} differs from scalar result {}",
            result,
            expected
        );
    }

    #[test]
    fn test_euclidean_distance_large() {
        // Test with 128-dim vectors (common embedding size)
        let a: Vec<f32> = (0..128).map(|i| i as f32 * 0.1).collect();
        let b: Vec<f32> = (0..128).map(|i| (i as f32 * 0.1) + 0.5).collect();

        let result = euclidean_distance_simd(&a, &b);
        let expected = euclidean_distance_scalar(&a, &b);

        assert!(
            (result - expected).abs() < 0.01,
            "Large vector: SIMD {} vs scalar {}",
            result,
            expected
        );
    }

    #[test]
    fn test_dot_product_simd() {
        let a = vec![1.0; 16];
        let b = vec![2.0; 16];

        let result = dot_product_simd(&a, &b);
        assert!((result - 32.0).abs() < 0.001);
    }

    #[test]
    fn test_dot_product_large() {
        let a: Vec<f32> = (0..256).map(|i| (i % 10) as f32).collect();
        let b: Vec<f32> = (0..256).map(|i| ((i + 5) % 10) as f32).collect();

        let result = dot_product_simd(&a, &b);
        let expected = dot_product_scalar(&a, &b);

        assert!(
            (result - expected).abs() < 0.1,
            "Large dot product: SIMD {} vs scalar {}",
            result,
            expected
        );
    }

    #[test]
    fn test_cosine_similarity_simd() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];

        let result = cosine_similarity_simd(&a, &b);
        assert!((result - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = vec![1.0, 0.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0, 0.0];

        let result = cosine_similarity_simd(&a, &b);
        assert!(
            result.abs() < 0.001,
            "Orthogonal vectors should have ~0 similarity, got {}",
            result
        );
    }

    #[test]
    fn test_cosine_similarity_non_aligned() {
        // 19 elements is not a multiple of 4, 8 or 16, so every kernel has a scalar tail.
        let a: Vec<f32> = (0..19).map(|i| ((i * 3) % 11) as f32 - 5.0).collect();
        let b: Vec<f32> = (0..19).map(|i| ((i * 7) % 13) as f32 - 6.0).collect();

        let result = cosine_similarity_simd(&a, &b);
        let expected = cosine_similarity_scalar(&a, &b);

        assert!(
            (result - expected).abs() < 1e-5,
            "Non-aligned cosine: SIMD {} vs scalar {}",
            result,
            expected
        );
    }

    #[test]
    fn test_manhattan_distance_simd() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];

        let result = manhattan_distance_simd(&a, &b);
        let expected = manhattan_distance_scalar(&a, &b);

        assert!(
            (result - expected).abs() < 0.001,
            "Manhattan: SIMD {} vs scalar {}",
            result,
            expected
        );
        assert!((result - 16.0).abs() < 0.001); // |4| + |4| + |4| + |4| = 16
    }

    #[test]
    fn test_manhattan_distance_vector_paths() {
        // The lengths cover these cases:
        // - empty input;
        // - a scalar tail only;
        // - a single 8-wide block;
        // - 16-wide blocks, with and without an extra 8-wide block;
        // - inputs on both sides of the 64-element NEON unrolling threshold.
        for &n in &[0usize, 7, 8, 16, 27, 63, 64, 100] {
            let a: Vec<f32> = (0..n).map(|i| (i % 7) as f32).collect();
            let b: Vec<f32> = (0..n).map(|i| (i % 5) as f32 - 2.0).collect();

            let result = manhattan_distance_simd(&a, &b);
            let expected = manhattan_distance_scalar(&a, &b);

            // Integer-valued inputs keep every partial sum exact, so a lane or tail
            // handling bug shows up as a mismatch rather than rounding noise.
            assert!(
                (result - expected).abs() < 0.001,
                "Manhattan len {}: SIMD {} vs scalar {}",
                n,
                result,
                expected
            );
        }
    }

    #[test]
    fn test_non_aligned_lengths() {
        // Test vectors not aligned to SIMD width (4 for NEON, 8 for AVX2)
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]; // 7 elements
        let b = vec![2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];

        let result = euclidean_distance_simd(&a, &b);
        let expected = euclidean_distance_scalar(&a, &b);

        assert!(
            (result - expected).abs() < 0.001,
            "Non-aligned: SIMD {} vs scalar {}",
            result,
            expected
        );
    }

    // Legacy function tests (ensure backward compatibility)
    #[test]
    fn test_legacy_avx2_aliases() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];

        // These should work identically to the _simd versions
        let _ = euclidean_distance_avx2(&a, &b);
        let _ = dot_product_avx2(&a, &b);
        let _ = cosine_similarity_avx2(&a, &b);
    }

    // INT8 quantized operation tests
    #[test]
    fn test_dot_product_i8() {
        let a: Vec<i8> = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
        let b: Vec<i8> = vec![2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17];

        let result = dot_product_i8(&a, &b);
        let expected = dot_product_i8_scalar(&a, &b);

        assert_eq!(
            result, expected,
            "INT8 dot product: SIMD {} vs scalar {}",
            result, expected
        );
    }

    #[test]
    fn test_dot_product_i8_large() {
        // Test with 128 elements (common for quantized embeddings)
        let a: Vec<i8> = (0..128)
            .map(|i| ((i % 256) as i8).wrapping_sub(64))
            .collect();
        let b: Vec<i8> = (0..128)
            .map(|i| (((i + 10) % 256) as i8).wrapping_sub(64))
            .collect();

        let result = dot_product_i8(&a, &b);
        let expected = dot_product_i8_scalar(&a, &b);

        assert_eq!(
            result, expected,
            "Large INT8 dot product: SIMD {} vs scalar {}",
            result, expected
        );
    }

    #[test]
    fn test_euclidean_distance_squared_i8() {
        let a: Vec<i8> = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
        let b: Vec<i8> = vec![2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17];

        let result = euclidean_distance_squared_i8(&a, &b);
        let expected = euclidean_distance_squared_i8_scalar(&a, &b);

        assert_eq!(
            result, expected,
            "INT8 euclidean^2: SIMD {} vs scalar {}",
            result, expected
        );
        // Each diff is 1, so 16 diffs squared = 16
        assert_eq!(result, 16, "Expected 16, got {}", result);
    }

    #[test]
    fn test_euclidean_distance_squared_i8_large() {
        let a: Vec<i8> = (0..128)
            .map(|i| ((i % 256) as i8).wrapping_sub(64))
            .collect();
        let b: Vec<i8> = (0..128)
            .map(|i| (((i + 5) % 256) as i8).wrapping_sub(64))
            .collect();

        let result = euclidean_distance_squared_i8(&a, &b);
        let expected = euclidean_distance_squared_i8_scalar(&a, &b);

        assert_eq!(
            result, expected,
            "Large INT8 euclidean^2: SIMD {} vs scalar {}",
            result, expected
        );
    }

    #[test]
    fn test_i8_extreme_values_non_aligned() {
        // 45 elements: one 32-byte AVX2 block plus a 13-element tail, or five 8-lane
        // NEON steps plus a 5-element tail. Extremes check that the i16 widening
        // never wraps.
        let a = vec![i8::MIN; 45];
        let b = vec![i8::MAX; 45];

        assert_eq!(dot_product_i8(&a, &b), dot_product_i8_scalar(&a, &b));
        assert_eq!(dot_product_i8(&a, &b), 45 * (-128 * 127));
        assert_eq!(
            euclidean_distance_squared_i8(&a, &b),
            euclidean_distance_squared_i8_scalar(&a, &b)
        );
        assert_eq!(euclidean_distance_squared_i8(&a, &b), 45 * 255 * 255);
    }

    // Batch operation tests
    #[test]
    fn test_batch_dot_product() {
        let query = vec![1.0, 2.0, 3.0, 4.0];
        let v1 = vec![1.0, 0.0, 0.0, 0.0];
        let v2 = vec![0.0, 1.0, 0.0, 0.0];
        let v3 = vec![0.0, 0.0, 1.0, 0.0];
        let vectors: Vec<&[f32]> = vec![&v1, &v2, &v3];
        let mut results = vec![0.0; 3];

        batch_dot_product(&query, &vectors, &mut results);

        assert!((results[0] - 1.0).abs() < 0.001);
        assert!((results[1] - 2.0).abs() < 0.001);
        assert!((results[2] - 3.0).abs() < 0.001);
    }

    #[test]
    fn test_batch_euclidean() {
        let query = vec![0.0, 0.0, 0.0, 0.0];
        let v1 = vec![3.0, 4.0, 0.0, 0.0];
        let v2 = vec![0.0, 0.0, 5.0, 12.0];
        let vectors: Vec<&[f32]> = vec![&v1, &v2];
        let mut results = vec![0.0; 2];

        batch_euclidean(&query, &vectors, &mut results);

        assert!(
            (results[0] - 5.0).abs() < 0.001,
            "Expected 5.0, got {}",
            results[0]
        );
        assert!(
            (results[1] - 13.0).abs() < 0.001,
            "Expected 13.0, got {}",
            results[1]
        );
    }

    #[test]
    fn test_batch_cosine_similarity() {
        let query = vec![1.0, 0.0, 0.0, 0.0];
        let v1 = vec![1.0, 0.0, 0.0, 0.0]; // Same direction
        let v2 = vec![0.0, 1.0, 0.0, 0.0]; // Orthogonal
        let v3 = vec![-1.0, 0.0, 0.0, 0.0]; // Opposite
        let vectors: Vec<&[f32]> = vec![&v1, &v2, &v3];
        let mut results = vec![0.0; 3];

        batch_cosine_similarity(&query, &vectors, &mut results);

        assert!(
            (results[0] - 1.0).abs() < 0.001,
            "Same direction should be 1.0"
        );
        assert!(results[1].abs() < 0.001, "Orthogonal should be 0.0");
        assert!((results[2] + 1.0).abs() < 0.001, "Opposite should be -1.0");
    }

    #[test]
    fn test_batch_matches_single_calls_for_many_vectors() {
        // 40 candidates, so any per-index bookkeeping error in the batch loops shows up.
        let query: Vec<f32> = (0..12).map(|i| i as f32 * 0.5 - 2.0).collect();
        let owned: Vec<Vec<f32>> = (0..40)
            .map(|j| (0..12).map(|i| ((i + j) % 9) as f32 - 4.0).collect())
            .collect();
        let refs: Vec<&[f32]> = owned.iter().map(|v| v.as_slice()).collect();
        let mut dots = vec![0.0; refs.len()];
        let mut dists = vec![0.0; refs.len()];

        batch_dot_product(&query, &refs, &mut dots);
        batch_euclidean(&query, &refs, &mut dists);

        for (j, v) in refs.iter().enumerate() {
            assert_eq!(dots[j], dot_product_simd(&query, v), "dot index {}", j);
            assert_eq!(dists[j], euclidean_distance_simd(&query, v), "dist index {}", j);
        }
    }

    #[test]
    fn test_batch_empty_input() {
        let query = vec![1.0, 2.0, 3.0, 4.0];
        let vectors: Vec<&[f32]> = Vec::new();
        let mut results: Vec<f32> = Vec::new();

        batch_dot_product(&query, &vectors, &mut results);
        batch_euclidean(&query, &vectors, &mut results);
        batch_cosine_similarity(&query, &vectors, &mut results);

        assert!(results.is_empty());
        assert!(batch_dot_product_owned(&query, &[]).is_empty());
        assert!(batch_euclidean_owned(&query, &[]).is_empty());
    }

    #[test]
    fn test_batch_owned_convenience() {
        let query = vec![1.0, 2.0, 3.0, 4.0];
        let vectors = vec![vec![1.0, 0.0, 0.0, 0.0], vec![0.0, 1.0, 0.0, 0.0]];

        let results = batch_dot_product_owned(&query, &vectors);
        assert_eq!(results.len(), 2);
        assert!((results[0] - 1.0).abs() < 0.001);
        assert!((results[1] - 2.0).abs() < 0.001);
    }

    #[test]
    fn test_unrolled_vs_non_unrolled_consistency() {
        // Lengths on both sides of the 64-element unrolling threshold. Each SIMD result is
        // checked against the scalar reference, so the two kernels must also agree with
        // each other.
        for &n in &[63usize, 64, 128, 131] {
            let a: Vec<f32> = (0..n).map(|i| i as f32 * 0.1).collect();
            let b: Vec<f32> = (0..n).map(|i| (i as f32 * 0.1) + 0.5).collect();

            let result = euclidean_distance_simd(&a, &b);
            let expected = euclidean_distance_scalar(&a, &b);

            assert!(
                (result - expected).abs() < 0.01,
                "Unrolled consistency (len {}): SIMD {} vs scalar {}",
                n,
                result,
                expected
            );
        }
    }
}