1#[cfg(target_arch = "x86_64")]
25use std::arch::x86_64::*;
26
27#[cfg(target_arch = "aarch64")]
28use std::arch::aarch64::*;
29
/// Prefetch look-ahead constant.
///
/// NOTE(review): no code visible in this part of the file reads this value
/// (hence the `dead_code` allowance); presumably a software-prefetch distance —
/// confirm the intended unit (elements vs. bytes) before using it.
#[allow(dead_code)]
const PREFETCH_DISTANCE: usize = 64;
33
34#[inline(always)]
42pub fn euclidean_distance_simd(a: &[f32], b: &[f32]) -> f32 {
43 #[cfg(target_arch = "x86_64")]
44 {
45 #[cfg(feature = "simd-avx512")]
46 {
47 if is_x86_feature_detected!("avx512f") {
48 return unsafe { euclidean_distance_avx512_impl(a, b) };
49 }
50 }
51 if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
52 unsafe { euclidean_distance_avx2_fma_impl(a, b) }
53 } else if is_x86_feature_detected!("avx2") {
54 unsafe { euclidean_distance_avx2_impl(a, b) }
55 } else {
56 euclidean_distance_scalar(a, b)
57 }
58 }
59
60 #[cfg(target_arch = "aarch64")]
61 {
62 if a.len() >= 64 {
64 unsafe { euclidean_distance_neon_unrolled_impl(a, b) }
65 } else {
66 unsafe { euclidean_distance_neon_impl(a, b) }
67 }
68 }
69
70 #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
71 {
72 euclidean_distance_scalar(a, b)
73 }
74}
75
76#[inline(always)]
78pub fn euclidean_distance_avx2(a: &[f32], b: &[f32]) -> f32 {
79 euclidean_distance_simd(a, b)
80}
81
/// AVX2 kernel for the Euclidean distance: processes 8 `f32` lanes per step
/// with a separate multiply and add, then finishes the leftover elements in
/// scalar code.
///
/// # Safety
///
/// The caller must ensure the CPU supports `avx2`.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn euclidean_distance_avx2_impl(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    const LANES: usize = 8;

    let blocks_a = a.chunks_exact(LANES);
    let blocks_b = b.chunks_exact(LANES);
    let (tail_a, tail_b) = (blocks_a.remainder(), blocks_b.remainder());

    // Vector part: accumulate (a - b)^2 lane-wise.
    let mut acc = _mm256_setzero_ps();
    for (block_a, block_b) in blocks_a.zip(blocks_b) {
        let delta = _mm256_sub_ps(
            _mm256_loadu_ps(block_a.as_ptr()),
            _mm256_loadu_ps(block_b.as_ptr()),
        );
        acc = _mm256_add_ps(acc, _mm256_mul_ps(delta, delta));
    }

    // Horizontal reduction: spill the register and add the lanes in order.
    let mut lanes = [0.0f32; LANES];
    _mm256_storeu_ps(lanes.as_mut_ptr(), acc);
    let mut total = lanes.iter().sum::<f32>();

    // Scalar part: the fewer-than-8 trailing elements.
    for (x, y) in tail_a.iter().zip(tail_b) {
        let delta = x - y;
        total += delta * delta;
    }

    total.sqrt()
}
122
/// AVX2+FMA kernel for the Euclidean distance.
///
/// Runs in three stages: a 32-element-per-iteration loop feeding four
/// independent fused-multiply-add accumulators, an 8-element loop over what is
/// left, and a scalar loop for the final fewer-than-8 elements.
///
/// # Safety
///
/// The caller must ensure the CPU supports both `avx2` and `fma`.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn euclidean_distance_avx2_fma_impl(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    const LANES: usize = 8;
    const UNROLL: usize = 4;
    const STRIDE: usize = LANES * UNROLL;

    let wide_a = a.chunks_exact(STRIDE);
    let wide_b = b.chunks_exact(STRIDE);
    let (rest_a, rest_b) = (wide_a.remainder(), wide_b.remainder());

    // Stage 1: four accumulators, one per 8-lane group of each 32-element block.
    let mut acc = [_mm256_setzero_ps(); UNROLL];
    for (block_a, block_b) in wide_a.zip(wide_b) {
        let (pa, pb) = (block_a.as_ptr(), block_b.as_ptr());
        for (group, slot) in acc.iter_mut().enumerate() {
            let delta = _mm256_sub_ps(
                _mm256_loadu_ps(pa.add(group * LANES)),
                _mm256_loadu_ps(pb.add(group * LANES)),
            );
            *slot = _mm256_fmadd_ps(delta, delta, *slot);
        }
    }

    // Pairwise merge of the accumulators: (0 + 1) + (2 + 3).
    let mut merged = _mm256_add_ps(
        _mm256_add_ps(acc[0], acc[1]),
        _mm256_add_ps(acc[2], acc[3]),
    );

    // Stage 2: remaining whole 8-lane groups.
    let narrow_a = rest_a.chunks_exact(LANES);
    let narrow_b = rest_b.chunks_exact(LANES);
    let (tail_a, tail_b) = (narrow_a.remainder(), narrow_b.remainder());
    for (block_a, block_b) in narrow_a.zip(narrow_b) {
        let delta = _mm256_sub_ps(
            _mm256_loadu_ps(block_a.as_ptr()),
            _mm256_loadu_ps(block_b.as_ptr()),
        );
        merged = _mm256_fmadd_ps(delta, delta, merged);
    }

    // Horizontal reduction: spill the register and add the lanes in order.
    let mut lanes = [0.0f32; LANES];
    _mm256_storeu_ps(lanes.as_mut_ptr(), merged);
    let mut total = lanes.iter().sum::<f32>();

    // Stage 3: scalar tail.
    for (x, y) in tail_a.iter().zip(tail_b) {
        let delta = x - y;
        total += delta * delta;
    }

    total.sqrt()
}
193
/// AVX-512F kernel for the Euclidean distance: 16 `f32` lanes per step with a
/// fused multiply-add, followed by a scalar loop over the trailing elements.
///
/// # Safety
///
/// The caller must ensure the CPU supports `avx512f`.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
#[cfg(all(target_arch = "x86_64", feature = "simd-avx512"))]
#[target_feature(enable = "avx512f")]
unsafe fn euclidean_distance_avx512_impl(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    const LANES: usize = 16;
    let (pa, pb) = (a.as_ptr(), b.as_ptr());
    // Index one past the last element covered by a whole 16-lane vector.
    let vector_end = a.len() / LANES * LANES;

    let mut acc = _mm512_setzero_ps();
    let mut offset = 0;
    while offset < vector_end {
        let delta = _mm512_sub_ps(
            _mm512_loadu_ps(pa.add(offset)),
            _mm512_loadu_ps(pb.add(offset)),
        );
        acc = _mm512_fmadd_ps(delta, delta, acc);
        offset += LANES;
    }

    let mut total = _mm512_reduce_add_ps(acc);

    for (x, y) in a[vector_end..].iter().zip(&b[vector_end..]) {
        let delta = x - y;
        total += delta * delta;
    }

    total.sqrt()
}
228
/// AVX-512F kernel for the dot product: 16 `f32` lanes per step with a fused
/// multiply-add, followed by a scalar loop over the trailing elements.
///
/// # Safety
///
/// The caller must ensure the CPU supports `avx512f`.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
#[cfg(all(target_arch = "x86_64", feature = "simd-avx512"))]
#[target_feature(enable = "avx512f")]
unsafe fn dot_product_avx512_impl(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    const LANES: usize = 16;
    let (pa, pb) = (a.as_ptr(), b.as_ptr());
    // Index one past the last element covered by a whole 16-lane vector.
    let vector_end = a.len() / LANES * LANES;

    let mut acc = _mm512_setzero_ps();
    let mut offset = 0;
    while offset < vector_end {
        acc = _mm512_fmadd_ps(
            _mm512_loadu_ps(pa.add(offset)),
            _mm512_loadu_ps(pb.add(offset)),
            acc,
        );
        offset += LANES;
    }

    let mut total = _mm512_reduce_add_ps(acc);

    for (x, y) in a[vector_end..].iter().zip(&b[vector_end..]) {
        total += x * y;
    }

    total
}
254
/// AVX-512F kernel for cosine similarity.
///
/// Accumulates the dot product and both squared norms in a single pass
/// (16 lanes per step), finishes the trailing elements in scalar code, and
/// returns `dot / (sqrt(|a|^2) * sqrt(|b|^2))`. There is no guard against a
/// zero norm, so an all-zero input yields NaN.
///
/// # Safety
///
/// The caller must ensure the CPU supports `avx512f`.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
#[cfg(all(target_arch = "x86_64", feature = "simd-avx512"))]
#[target_feature(enable = "avx512f")]
unsafe fn cosine_similarity_avx512_impl(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    const LANES: usize = 16;
    let (pa, pb) = (a.as_ptr(), b.as_ptr());
    // Index one past the last element covered by a whole 16-lane vector.
    let vector_end = a.len() / LANES * LANES;

    let mut dot_acc = _mm512_setzero_ps();
    let mut sq_a_acc = _mm512_setzero_ps();
    let mut sq_b_acc = _mm512_setzero_ps();

    let mut offset = 0;
    while offset < vector_end {
        let va = _mm512_loadu_ps(pa.add(offset));
        let vb = _mm512_loadu_ps(pb.add(offset));

        dot_acc = _mm512_fmadd_ps(va, vb, dot_acc);
        sq_a_acc = _mm512_fmadd_ps(va, va, sq_a_acc);
        sq_b_acc = _mm512_fmadd_ps(vb, vb, sq_b_acc);

        offset += LANES;
    }

    let mut dot_sum = _mm512_reduce_add_ps(dot_acc);
    let mut norm_a_sum = _mm512_reduce_add_ps(sq_a_acc);
    let mut norm_b_sum = _mm512_reduce_add_ps(sq_b_acc);

    for (&x, &y) in a[vector_end..].iter().zip(&b[vector_end..]) {
        dot_sum += x * y;
        norm_a_sum += x * x;
        norm_b_sum += y * y;
    }

    dot_sum / (norm_a_sum.sqrt() * norm_b_sum.sqrt())
}
289
/// AVX-512F kernel for the Manhattan (L1) distance: sums `|a - b|` over
/// 16 lanes per step, then handles the trailing elements in scalar code.
///
/// # Safety
///
/// The caller must ensure the CPU supports `avx512f`.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
#[cfg(all(target_arch = "x86_64", feature = "simd-avx512"))]
#[target_feature(enable = "avx512f")]
unsafe fn manhattan_distance_avx512_impl(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    const LANES: usize = 16;
    let (pa, pb) = (a.as_ptr(), b.as_ptr());
    // Index one past the last element covered by a whole 16-lane vector.
    let vector_end = a.len() / LANES * LANES;

    let mut acc = _mm512_setzero_ps();
    let mut offset = 0;
    while offset < vector_end {
        let delta = _mm512_sub_ps(
            _mm512_loadu_ps(pa.add(offset)),
            _mm512_loadu_ps(pb.add(offset)),
        );
        acc = _mm512_add_ps(acc, _mm512_abs_ps(delta));
        offset += LANES;
    }

    let mut total = _mm512_reduce_add_ps(acc);

    for (x, y) in a[vector_end..].iter().zip(&b[vector_end..]) {
        total += (x - y).abs();
    }

    total
}
317
/// NEON kernel for the Euclidean distance: 4 `f32` lanes per step with a fused
/// multiply-add, followed by a scalar loop over the trailing elements.
///
/// # Safety
///
/// Declared `unsafe` because it is built on NEON intrinsics and raw-pointer
/// loads. The length precondition is checked at runtime in all build profiles.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
#[allow(dead_code)]
unsafe fn euclidean_distance_neon_impl(a: &[f32], b: &[f32]) -> f32 {
    // Hard assert (not `debug_assert`): the raw-pointer loads below index `b`
    // with offsets derived from `a.len()`, so a mismatch in a release build
    // would read out of bounds.
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let mut sum = vdupq_n_f32(0.0);

    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    let chunks = len / 4;
    let mut idx = 0usize;

    for _ in 0..chunks {
        let va = vld1q_f32(a_ptr.add(idx));
        let vb = vld1q_f32(b_ptr.add(idx));

        let diff = vsubq_f32(va, vb);

        // sum += diff * diff (fused)
        sum = vfmaq_f32(sum, diff, diff);

        idx += 4;
    }

    // Horizontal add across the four lanes.
    let mut total = vaddvq_f32(sum);

    // Trailing fewer-than-4 elements, using safe slicing.
    for (x, y) in a[idx..].iter().zip(&b[idx..]) {
        let diff = x - y;
        total += diff * diff;
    }

    total.sqrt()
}
367
/// NEON kernel for the dot product: 4 `f32` lanes per step with a fused
/// multiply-add, followed by a scalar loop over the trailing elements.
///
/// # Safety
///
/// Declared `unsafe` because it is built on NEON intrinsics and raw-pointer
/// loads. The length precondition is checked at runtime in all build profiles.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn dot_product_neon_impl(a: &[f32], b: &[f32]) -> f32 {
    // Hard assert (not `debug_assert`): the raw-pointer loads below index `b`
    // with offsets derived from `a.len()`, so a mismatch in a release build
    // would read out of bounds.
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let mut sum = vdupq_n_f32(0.0);

    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    let chunks = len / 4;
    let mut idx = 0usize;

    for _ in 0..chunks {
        let va = vld1q_f32(a_ptr.add(idx));
        let vb = vld1q_f32(b_ptr.add(idx));

        // sum += va * vb (fused)
        sum = vfmaq_f32(sum, va, vb);

        idx += 4;
    }

    // Horizontal add across the four lanes.
    let mut total = vaddvq_f32(sum);

    // Trailing fewer-than-4 elements, using safe slicing.
    for (x, y) in a[idx..].iter().zip(&b[idx..]) {
        total += x * y;
    }

    total
}
405
/// NEON kernel for cosine similarity.
///
/// Accumulates the dot product and both squared norms in a single pass
/// (4 lanes per step), finishes the trailing elements in scalar code, and
/// returns `dot / (sqrt(|a|^2) * sqrt(|b|^2))`. There is no guard against a
/// zero norm, so an all-zero input yields NaN.
///
/// # Safety
///
/// Declared `unsafe` because it is built on NEON intrinsics and raw-pointer
/// loads. The length precondition is checked at runtime in all build profiles.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn cosine_similarity_neon_impl(a: &[f32], b: &[f32]) -> f32 {
    // Hard assert (not `debug_assert`): the raw-pointer loads below index `b`
    // with offsets derived from `a.len()`, so a mismatch in a release build
    // would read out of bounds.
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let mut dot = vdupq_n_f32(0.0);
    let mut norm_a = vdupq_n_f32(0.0);
    let mut norm_b = vdupq_n_f32(0.0);

    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    let chunks = len / 4;
    let mut idx = 0usize;

    for _ in 0..chunks {
        let va = vld1q_f32(a_ptr.add(idx));
        let vb = vld1q_f32(b_ptr.add(idx));

        dot = vfmaq_f32(dot, va, vb);
        norm_a = vfmaq_f32(norm_a, va, va);
        norm_b = vfmaq_f32(norm_b, vb, vb);

        idx += 4;
    }

    // Horizontal adds across the four lanes of each accumulator.
    let mut dot_sum = vaddvq_f32(dot);
    let mut norm_a_sum = vaddvq_f32(norm_a);
    let mut norm_b_sum = vaddvq_f32(norm_b);

    // Trailing fewer-than-4 elements, using safe slicing.
    for (&ai, &bi) in a[idx..].iter().zip(&b[idx..]) {
        dot_sum += ai * bi;
        norm_a_sum += ai * ai;
        norm_b_sum += bi * bi;
    }

    dot_sum / (norm_a_sum.sqrt() * norm_b_sum.sqrt())
}
455
/// NEON kernel for the Manhattan (L1) distance: sums the lane-wise absolute
/// difference over 4 lanes per step, then handles the trailing elements in
/// scalar code.
///
/// # Safety
///
/// Declared `unsafe` because it is built on NEON intrinsics and raw-pointer
/// loads. The length precondition is checked at runtime in all build profiles.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn manhattan_distance_neon_impl(a: &[f32], b: &[f32]) -> f32 {
    // Hard assert (not `debug_assert`): the raw-pointer loads below index `b`
    // with offsets derived from `a.len()`, so a mismatch in a release build
    // would read out of bounds.
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let mut sum = vdupq_n_f32(0.0);

    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    let chunks = len / 4;
    let mut idx = 0usize;

    for _ in 0..chunks {
        let va = vld1q_f32(a_ptr.add(idx));
        let vb = vld1q_f32(b_ptr.add(idx));

        // |va - vb| in one instruction.
        let abs_diff = vabdq_f32(va, vb);
        sum = vaddq_f32(sum, abs_diff);

        idx += 4;
    }

    // Horizontal add across the four lanes.
    let mut total = vaddvq_f32(sum);

    // Trailing fewer-than-4 elements, using safe slicing.
    for (x, y) in a[idx..].iter().zip(&b[idx..]) {
        total += (x - y).abs();
    }

    total
}
494
/// Unrolled NEON kernel for the Euclidean distance.
///
/// Runs in three stages: a 16-element-per-iteration loop feeding four
/// independent fused-multiply-add accumulators, a 4-element loop over what is
/// left, and a scalar loop for the final fewer-than-4 elements.
///
/// # Safety
///
/// Declared `unsafe` because it is built on NEON intrinsics and raw-pointer
/// loads. The length precondition is checked at runtime in all build profiles.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn euclidean_distance_neon_unrolled_impl(a: &[f32], b: &[f32]) -> f32 {
    // Hard assert (not `debug_assert`): the raw-pointer loads below index `b`
    // with offsets derived from `a.len()`, so a mismatch in a release build
    // would read out of bounds.
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    let mut sum0 = vdupq_n_f32(0.0);
    let mut sum1 = vdupq_n_f32(0.0);
    let mut sum2 = vdupq_n_f32(0.0);
    let mut sum3 = vdupq_n_f32(0.0);

    // Stage 1: 16 elements per iteration, one accumulator per 4-lane group.
    let chunks = len / 16;
    let mut idx = 0usize;

    for _ in 0..chunks {
        let diff0 = vsubq_f32(vld1q_f32(a_ptr.add(idx)), vld1q_f32(b_ptr.add(idx)));
        sum0 = vfmaq_f32(sum0, diff0, diff0);

        let diff1 = vsubq_f32(vld1q_f32(a_ptr.add(idx + 4)), vld1q_f32(b_ptr.add(idx + 4)));
        sum1 = vfmaq_f32(sum1, diff1, diff1);

        let diff2 = vsubq_f32(vld1q_f32(a_ptr.add(idx + 8)), vld1q_f32(b_ptr.add(idx + 8)));
        sum2 = vfmaq_f32(sum2, diff2, diff2);

        let diff3 = vsubq_f32(vld1q_f32(a_ptr.add(idx + 12)), vld1q_f32(b_ptr.add(idx + 12)));
        sum3 = vfmaq_f32(sum3, diff3, diff3);

        idx += 16;
    }

    // Pairwise merge of the accumulators: (0 + 1) + (2 + 3).
    let mut final_sum = vaddq_f32(vaddq_f32(sum0, sum1), vaddq_f32(sum2, sum3));

    // Stage 2: remaining whole 4-lane groups; `idx` is already at `chunks * 16`.
    let remaining_chunks = (len - idx) / 4;
    for _ in 0..remaining_chunks {
        let diff = vsubq_f32(vld1q_f32(a_ptr.add(idx)), vld1q_f32(b_ptr.add(idx)));
        final_sum = vfmaq_f32(final_sum, diff, diff);
        idx += 4;
    }

    let mut total = vaddvq_f32(final_sum);

    // Stage 3: scalar tail, using safe slicing.
    for (x, y) in a[idx..].iter().zip(&b[idx..]) {
        let diff = x - y;
        total += diff * diff;
    }

    total.sqrt()
}
580
/// Unrolled NEON kernel for the dot product.
///
/// Runs in three stages: a 16-element-per-iteration loop feeding four
/// independent fused-multiply-add accumulators, a 4-element loop over what is
/// left, and a scalar loop for the final fewer-than-4 elements.
///
/// # Safety
///
/// Declared `unsafe` because it is built on NEON intrinsics and raw-pointer
/// loads. The length precondition is checked at runtime in all build profiles.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn dot_product_neon_unrolled_impl(a: &[f32], b: &[f32]) -> f32 {
    // Hard assert (not `debug_assert`): the raw-pointer loads below index `b`
    // with offsets derived from `a.len()`, so a mismatch in a release build
    // would read out of bounds.
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    let mut sum0 = vdupq_n_f32(0.0);
    let mut sum1 = vdupq_n_f32(0.0);
    let mut sum2 = vdupq_n_f32(0.0);
    let mut sum3 = vdupq_n_f32(0.0);

    // Stage 1: 16 elements per iteration, one accumulator per 4-lane group.
    let chunks = len / 16;
    let mut idx = 0usize;

    for _ in 0..chunks {
        sum0 = vfmaq_f32(sum0, vld1q_f32(a_ptr.add(idx)), vld1q_f32(b_ptr.add(idx)));
        sum1 = vfmaq_f32(sum1, vld1q_f32(a_ptr.add(idx + 4)), vld1q_f32(b_ptr.add(idx + 4)));
        sum2 = vfmaq_f32(sum2, vld1q_f32(a_ptr.add(idx + 8)), vld1q_f32(b_ptr.add(idx + 8)));
        sum3 = vfmaq_f32(sum3, vld1q_f32(a_ptr.add(idx + 12)), vld1q_f32(b_ptr.add(idx + 12)));

        idx += 16;
    }

    // Pairwise merge of the accumulators: (0 + 1) + (2 + 3).
    let mut final_sum = vaddq_f32(vaddq_f32(sum0, sum1), vaddq_f32(sum2, sum3));

    // Stage 2: remaining whole 4-lane groups; `idx` is already at `chunks * 16`.
    let remaining_chunks = (len - idx) / 4;
    for _ in 0..remaining_chunks {
        final_sum = vfmaq_f32(final_sum, vld1q_f32(a_ptr.add(idx)), vld1q_f32(b_ptr.add(idx)));
        idx += 4;
    }

    let mut total = vaddvq_f32(final_sum);

    // Stage 3: scalar tail, using safe slicing.
    for (x, y) in a[idx..].iter().zip(&b[idx..]) {
        total += x * y;
    }

    total
}
649
/// Unrolled NEON kernel for cosine similarity.
///
/// Processes 8 elements per iteration with two independent accumulator sets
/// (dot product, squared norm of `a`, squared norm of `b`), then finishes all
/// remaining elements in scalar code and returns
/// `dot / (sqrt(|a|^2) * sqrt(|b|^2))`. There is no guard against a zero norm,
/// so an all-zero input yields NaN.
///
/// # Safety
///
/// Declared `unsafe` because it is built on NEON intrinsics and raw-pointer
/// loads. The length precondition is checked at runtime in all build profiles.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn cosine_similarity_neon_unrolled_impl(a: &[f32], b: &[f32]) -> f32 {
    // Hard assert (not `debug_assert`): the raw-pointer loads below index `b`
    // with offsets derived from `a.len()`, so a mismatch in a release build
    // would read out of bounds.
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    let mut dot0 = vdupq_n_f32(0.0);
    let mut dot1 = vdupq_n_f32(0.0);
    let mut norm_a0 = vdupq_n_f32(0.0);
    let mut norm_a1 = vdupq_n_f32(0.0);
    let mut norm_b0 = vdupq_n_f32(0.0);
    let mut norm_b1 = vdupq_n_f32(0.0);

    let chunks = len / 8;
    let mut idx = 0usize;

    for _ in 0..chunks {
        let va0 = vld1q_f32(a_ptr.add(idx));
        let vb0 = vld1q_f32(b_ptr.add(idx));
        dot0 = vfmaq_f32(dot0, va0, vb0);
        norm_a0 = vfmaq_f32(norm_a0, va0, va0);
        norm_b0 = vfmaq_f32(norm_b0, vb0, vb0);

        let va1 = vld1q_f32(a_ptr.add(idx + 4));
        let vb1 = vld1q_f32(b_ptr.add(idx + 4));
        dot1 = vfmaq_f32(dot1, va1, vb1);
        norm_a1 = vfmaq_f32(norm_a1, va1, va1);
        norm_b1 = vfmaq_f32(norm_b1, vb1, vb1);

        idx += 8;
    }

    // Merge the two accumulator sets, then reduce each across its lanes.
    let mut dot_sum = vaddvq_f32(vaddq_f32(dot0, dot1));
    let mut norm_a_sum = vaddvq_f32(vaddq_f32(norm_a0, norm_a1));
    let mut norm_b_sum = vaddvq_f32(vaddq_f32(norm_b0, norm_b1));

    // Trailing fewer-than-8 elements, using safe slicing.
    for (&ai, &bi) in a[idx..].iter().zip(&b[idx..]) {
        dot_sum += ai * bi;
        norm_a_sum += ai * ai;
        norm_b_sum += bi * bi;
    }

    dot_sum / (norm_a_sum.sqrt() * norm_b_sum.sqrt())
}
709
/// Unrolled NEON kernel for the Manhattan (L1) distance.
///
/// Runs in three stages: a 16-element-per-iteration loop feeding four
/// independent accumulators of lane-wise absolute differences, a 4-element
/// loop over what is left, and a scalar loop for the final fewer-than-4
/// elements.
///
/// # Safety
///
/// Declared `unsafe` because it is built on NEON intrinsics and raw-pointer
/// loads. The length precondition is checked at runtime in all build profiles.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn manhattan_distance_neon_unrolled_impl(a: &[f32], b: &[f32]) -> f32 {
    // Hard assert (not `debug_assert`): the raw-pointer loads below index `b`
    // with offsets derived from `a.len()`, so a mismatch in a release build
    // would read out of bounds.
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    let mut sum0 = vdupq_n_f32(0.0);
    let mut sum1 = vdupq_n_f32(0.0);
    let mut sum2 = vdupq_n_f32(0.0);
    let mut sum3 = vdupq_n_f32(0.0);

    // Stage 1: 16 elements per iteration, one accumulator per 4-lane group.
    let chunks = len / 16;
    let mut idx = 0usize;

    for _ in 0..chunks {
        sum0 = vaddq_f32(
            sum0,
            vabdq_f32(vld1q_f32(a_ptr.add(idx)), vld1q_f32(b_ptr.add(idx))),
        );
        sum1 = vaddq_f32(
            sum1,
            vabdq_f32(vld1q_f32(a_ptr.add(idx + 4)), vld1q_f32(b_ptr.add(idx + 4))),
        );
        sum2 = vaddq_f32(
            sum2,
            vabdq_f32(vld1q_f32(a_ptr.add(idx + 8)), vld1q_f32(b_ptr.add(idx + 8))),
        );
        sum3 = vaddq_f32(
            sum3,
            vabdq_f32(vld1q_f32(a_ptr.add(idx + 12)), vld1q_f32(b_ptr.add(idx + 12))),
        );

        idx += 16;
    }

    // Pairwise merge of the accumulators: (0 + 1) + (2 + 3).
    let mut final_sum = vaddq_f32(vaddq_f32(sum0, sum1), vaddq_f32(sum2, sum3));

    // Stage 2: remaining whole 4-lane groups; `idx` is already at `chunks * 16`.
    let remaining_chunks = (len - idx) / 4;
    for _ in 0..remaining_chunks {
        final_sum = vaddq_f32(
            final_sum,
            vabdq_f32(vld1q_f32(a_ptr.add(idx)), vld1q_f32(b_ptr.add(idx))),
        );
        idx += 4;
    }

    let mut total = vaddvq_f32(final_sum);

    // Stage 3: scalar tail, using safe slicing.
    for (x, y) in a[idx..].iter().zip(&b[idx..]) {
        total += (x - y).abs();
    }

    total
}
779
780#[inline(always)]
787pub fn dot_product_simd(a: &[f32], b: &[f32]) -> f32 {
788 #[cfg(target_arch = "x86_64")]
789 {
790 #[cfg(feature = "simd-avx512")]
791 {
792 if is_x86_feature_detected!("avx512f") {
793 return unsafe { dot_product_avx512_impl(a, b) };
794 }
795 }
796 if is_x86_feature_detected!("avx2") {
797 unsafe { dot_product_avx2_impl(a, b) }
798 } else {
799 dot_product_scalar(a, b)
800 }
801 }
802
803 #[cfg(target_arch = "aarch64")]
804 {
805 if a.len() >= 64 {
806 unsafe { dot_product_neon_unrolled_impl(a, b) }
807 } else {
808 unsafe { dot_product_neon_impl(a, b) }
809 }
810 }
811
812 #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
813 {
814 dot_product_scalar(a, b)
815 }
816}
817
818#[inline(always)]
820pub fn dot_product_avx2(a: &[f32], b: &[f32]) -> f32 {
821 dot_product_simd(a, b)
822}
823
/// AVX2 kernel for the dot product: 8 `f32` lanes per step with a separate
/// multiply and add, followed by a scalar loop over the trailing elements.
///
/// # Safety
///
/// The caller must ensure the CPU supports `avx2`.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn dot_product_avx2_impl(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    const LANES: usize = 8;
    let (pa, pb) = (a.as_ptr(), b.as_ptr());
    // Index one past the last element covered by a whole 8-lane vector.
    let vector_end = a.len() / LANES * LANES;

    let mut acc = _mm256_setzero_ps();
    let mut offset = 0;
    while offset < vector_end {
        let product = _mm256_mul_ps(
            _mm256_loadu_ps(pa.add(offset)),
            _mm256_loadu_ps(pb.add(offset)),
        );
        acc = _mm256_add_ps(acc, product);
        offset += LANES;
    }

    // Horizontal reduction: spill the register and add the lanes in order.
    let mut lanes = [0.0f32; LANES];
    _mm256_storeu_ps(lanes.as_mut_ptr(), acc);
    let mut total = lanes.iter().sum::<f32>();

    for (x, y) in a[vector_end..].iter().zip(&b[vector_end..]) {
        total += x * y;
    }

    total
}
851
852#[inline(always)]
855pub fn cosine_similarity_simd(a: &[f32], b: &[f32]) -> f32 {
856 #[cfg(target_arch = "x86_64")]
857 {
858 #[cfg(feature = "simd-avx512")]
859 {
860 if is_x86_feature_detected!("avx512f") {
861 return unsafe { cosine_similarity_avx512_impl(a, b) };
862 }
863 }
864 if is_x86_feature_detected!("avx2") {
865 unsafe { cosine_similarity_avx2_impl(a, b) }
866 } else {
867 cosine_similarity_scalar(a, b)
868 }
869 }
870
871 #[cfg(target_arch = "aarch64")]
872 {
873 if a.len() >= 64 {
874 unsafe { cosine_similarity_neon_unrolled_impl(a, b) }
875 } else {
876 unsafe { cosine_similarity_neon_impl(a, b) }
877 }
878 }
879
880 #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
881 {
882 cosine_similarity_scalar(a, b)
883 }
884}
885
886#[inline(always)]
888pub fn cosine_similarity_avx2(a: &[f32], b: &[f32]) -> f32 {
889 cosine_similarity_simd(a, b)
890}
891
892#[inline(always)]
895pub fn manhattan_distance_simd(a: &[f32], b: &[f32]) -> f32 {
896 #[cfg(target_arch = "x86_64")]
897 {
898 #[cfg(feature = "simd-avx512")]
899 {
900 if is_x86_feature_detected!("avx512f") {
901 return unsafe { manhattan_distance_avx512_impl(a, b) };
902 }
903 }
904 if is_x86_feature_detected!("avx2") {
905 unsafe { manhattan_distance_avx2_impl(a, b) }
906 } else {
907 manhattan_distance_scalar(a, b)
908 }
909 }
910
911 #[cfg(target_arch = "aarch64")]
912 {
913 if a.len() >= 64 {
914 unsafe { manhattan_distance_neon_unrolled_impl(a, b) }
915 } else {
916 unsafe { manhattan_distance_neon_impl(a, b) }
917 }
918 }
919
920 #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
921 {
922 manhattan_distance_scalar(a, b)
923 }
924}
925
/// AVX2 kernel for cosine similarity.
///
/// Accumulates the dot product and both squared norms in a single pass
/// (8 lanes per step, separate multiply and add), finishes the trailing
/// elements in scalar code, and returns `dot / (sqrt(|a|^2) * sqrt(|b|^2))`.
/// There is no guard against a zero norm, so an all-zero input yields NaN.
///
/// # Safety
///
/// The caller must ensure the CPU supports `avx2`.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn cosine_similarity_avx2_impl(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    const LANES: usize = 8;

    let blocks_a = a.chunks_exact(LANES);
    let blocks_b = b.chunks_exact(LANES);
    let (tail_a, tail_b) = (blocks_a.remainder(), blocks_b.remainder());

    let mut dot_acc = _mm256_setzero_ps();
    let mut sq_a_acc = _mm256_setzero_ps();
    let mut sq_b_acc = _mm256_setzero_ps();

    for (block_a, block_b) in blocks_a.zip(blocks_b) {
        let va = _mm256_loadu_ps(block_a.as_ptr());
        let vb = _mm256_loadu_ps(block_b.as_ptr());

        dot_acc = _mm256_add_ps(dot_acc, _mm256_mul_ps(va, vb));
        sq_a_acc = _mm256_add_ps(sq_a_acc, _mm256_mul_ps(va, va));
        sq_b_acc = _mm256_add_ps(sq_b_acc, _mm256_mul_ps(vb, vb));
    }

    // Horizontal reductions: spill each register and add its lanes in order.
    let mut dot_lanes = [0.0f32; LANES];
    let mut sq_a_lanes = [0.0f32; LANES];
    let mut sq_b_lanes = [0.0f32; LANES];
    _mm256_storeu_ps(dot_lanes.as_mut_ptr(), dot_acc);
    _mm256_storeu_ps(sq_a_lanes.as_mut_ptr(), sq_a_acc);
    _mm256_storeu_ps(sq_b_lanes.as_mut_ptr(), sq_b_acc);

    let mut dot_sum = dot_lanes.iter().sum::<f32>();
    let mut norm_a_sum = sq_a_lanes.iter().sum::<f32>();
    let mut norm_b_sum = sq_b_lanes.iter().sum::<f32>();

    for (&x, &y) in tail_a.iter().zip(tail_b) {
        dot_sum += x * y;
        norm_a_sum += x * x;
        norm_b_sum += y * y;
    }

    dot_sum / (norm_a_sum.sqrt() * norm_b_sum.sqrt())
}
967
/// AVX2 kernel for the Manhattan (L1) distance.
///
/// Runs in three stages: a 16-element-per-iteration loop feeding two
/// independent accumulators, an 8-element loop over what is left, and a scalar
/// loop for the final fewer-than-8 elements. Absolute values are taken by
/// masking off the IEEE-754 sign bit.
///
/// # Safety
///
/// The caller must ensure the CPU supports `avx2`.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn manhattan_distance_avx2_impl(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    const LANES: usize = 8;
    const STRIDE: usize = 2 * LANES;

    // All bits set except the sign bit; AND-ing with it yields |x|.
    let abs_mask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7FFF_FFFF));

    let wide_a = a.chunks_exact(STRIDE);
    let wide_b = b.chunks_exact(STRIDE);
    let (rest_a, rest_b) = (wide_a.remainder(), wide_b.remainder());

    // Stage 1: two accumulators, one per 8-lane half of each 16-element block.
    let mut acc_lo = _mm256_setzero_ps();
    let mut acc_hi = _mm256_setzero_ps();
    for (block_a, block_b) in wide_a.zip(wide_b) {
        let (pa, pb) = (block_a.as_ptr(), block_b.as_ptr());

        let delta_lo = _mm256_sub_ps(_mm256_loadu_ps(pa), _mm256_loadu_ps(pb));
        acc_lo = _mm256_add_ps(acc_lo, _mm256_and_ps(delta_lo, abs_mask));

        let delta_hi = _mm256_sub_ps(
            _mm256_loadu_ps(pa.add(LANES)),
            _mm256_loadu_ps(pb.add(LANES)),
        );
        acc_hi = _mm256_add_ps(acc_hi, _mm256_and_ps(delta_hi, abs_mask));
    }
    let mut merged = _mm256_add_ps(acc_lo, acc_hi);

    // Stage 2: remaining whole 8-lane groups.
    let narrow_a = rest_a.chunks_exact(LANES);
    let narrow_b = rest_b.chunks_exact(LANES);
    let (tail_a, tail_b) = (narrow_a.remainder(), narrow_b.remainder());
    for (block_a, block_b) in narrow_a.zip(narrow_b) {
        let delta = _mm256_sub_ps(
            _mm256_loadu_ps(block_a.as_ptr()),
            _mm256_loadu_ps(block_b.as_ptr()),
        );
        merged = _mm256_add_ps(merged, _mm256_and_ps(delta, abs_mask));
    }

    // Horizontal reduction: spill the register and add the lanes in order.
    let mut lanes = [0.0f32; LANES];
    _mm256_storeu_ps(lanes.as_mut_ptr(), merged);
    let mut total = lanes.iter().sum::<f32>();

    // Stage 3: scalar tail.
    for (x, y) in tail_a.iter().zip(tail_b) {
        total += (x - y).abs();
    }

    total
}
1024
/// Portable reference implementation of the Euclidean (L2) distance.
///
/// Elements are paired with `zip`, so when the slices differ in length only
/// the common prefix contributes to the result.
#[allow(dead_code)]
fn euclidean_distance_scalar(a: &[f32], b: &[f32]) -> f32 {
    let squared_deltas = a.iter().zip(b).map(|(&lhs, &rhs)| {
        let delta = lhs - rhs;
        delta * delta
    });
    let sum_of_squares: f32 = squared_deltas.sum();
    sum_of_squares.sqrt()
}
1039
/// Portable reference implementation of the dot product.
///
/// Elements are paired with `zip`, so when the slices differ in length only
/// the common prefix contributes to the result.
#[allow(dead_code)]
fn dot_product_scalar(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum::<f32>()
}
1044
/// Portable reference implementation of cosine similarity.
///
/// Returns `dot(a, b) / (|a| * |b|)`. The dot product pairs elements with
/// `zip` (common prefix only), while each norm is taken over its whole slice.
/// There is no guard against a zero norm, so an all-zero input yields NaN.
#[allow(dead_code)]
fn cosine_similarity_scalar(a: &[f32], b: &[f32]) -> f32 {
    // Euclidean norm of a single slice.
    let magnitude = |values: &[f32]| -> f32 {
        values.iter().map(|v| v * v).sum::<f32>().sqrt()
    };

    let dot: f32 = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs).sum();
    let magnitude_a = magnitude(a);
    let magnitude_b = magnitude(b);

    dot / (magnitude_a * magnitude_b)
}
1052
/// Portable reference implementation of the Manhattan (L1) distance.
///
/// Elements are paired with `zip`, so when the slices differ in length only
/// the common prefix contributes to the result.
#[allow(dead_code)]
fn manhattan_distance_scalar(a: &[f32], b: &[f32]) -> f32 {
    let absolute_deltas = a.iter().zip(b).map(|(&lhs, &rhs)| (lhs - rhs).abs());
    absolute_deltas.sum::<f32>()
}
1057
1058#[inline(always)]
1065pub fn dot_product_i8(a: &[i8], b: &[i8]) -> i32 {
1066 #[cfg(target_arch = "x86_64")]
1067 {
1068 if is_x86_feature_detected!("avx2") {
1069 unsafe { dot_product_i8_avx2_impl(a, b) }
1070 } else {
1071 dot_product_i8_scalar(a, b)
1072 }
1073 }
1074
1075 #[cfg(target_arch = "aarch64")]
1076 {
1077 unsafe { dot_product_i8_neon_impl(a, b) }
1078 }
1079
1080 #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
1081 {
1082 dot_product_i8_scalar(a, b)
1083 }
1084}
1085
1086#[inline(always)]
1089pub fn euclidean_distance_squared_i8(a: &[i8], b: &[i8]) -> i32 {
1090 #[cfg(target_arch = "x86_64")]
1091 {
1092 if is_x86_feature_detected!("avx2") {
1093 unsafe { euclidean_distance_squared_i8_avx2_impl(a, b) }
1094 } else {
1095 euclidean_distance_squared_i8_scalar(a, b)
1096 }
1097 }
1098
1099 #[cfg(target_arch = "aarch64")]
1100 {
1101 unsafe { euclidean_distance_squared_i8_neon_impl(a, b) }
1102 }
1103
1104 #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
1105 {
1106 euclidean_distance_squared_i8_scalar(a, b)
1107 }
1108}
1109
/// NEON kernel for the `i8` dot product.
///
/// Each step loads 8 bytes per input, widens them to `i16`, multiplies the
/// low and high halves into `i32` lanes and adds both into a 4-lane `i32`
/// accumulator. Trailing elements are handled in scalar code.
///
/// # Safety
///
/// Declared `unsafe` because it is built on NEON intrinsics and raw-pointer
/// loads. The length precondition is checked at runtime in all build profiles.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn dot_product_i8_neon_impl(a: &[i8], b: &[i8]) -> i32 {
    // Hard assert (not `debug_assert`): the raw-pointer loads below index `b`
    // with offsets derived from `a.len()`, so a mismatch in a release build
    // would read out of bounds.
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    let mut sum = vdupq_n_s32(0);

    let chunks = len / 8;
    let mut idx = 0usize;

    for _ in 0..chunks {
        let va = vld1_s8(a_ptr.add(idx));
        let vb = vld1_s8(b_ptr.add(idx));

        // Sign-extend i8 -> i16 so the widening multiply can produce i32.
        let va_i16 = vmovl_s8(va);
        let vb_i16 = vmovl_s8(vb);

        let prod_lo = vmull_s16(vget_low_s16(va_i16), vget_low_s16(vb_i16));
        let prod_hi = vmull_s16(vget_high_s16(va_i16), vget_high_s16(vb_i16));

        sum = vaddq_s32(sum, prod_lo);
        sum = vaddq_s32(sum, prod_hi);

        idx += 8;
    }

    // Horizontal add across the four i32 lanes.
    let mut total = vaddvq_s32(sum);

    // Trailing fewer-than-8 elements, using safe slicing.
    for (&x, &y) in a[idx..].iter().zip(&b[idx..]) {
        total += (x as i32) * (y as i32);
    }

    total
}
1159
/// NEON kernel for the squared Euclidean distance between `i8` vectors.
///
/// Each step loads 8 bytes per input, widens them to `i16`, subtracts, squares
/// the low and high halves into `i32` lanes and adds both into a 4-lane `i32`
/// accumulator. Trailing elements are handled in scalar code.
///
/// # Safety
///
/// Declared `unsafe` because it is built on NEON intrinsics and raw-pointer
/// loads. The length precondition is checked at runtime in all build profiles.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn euclidean_distance_squared_i8_neon_impl(a: &[i8], b: &[i8]) -> i32 {
    // Hard assert (not `debug_assert`): the raw-pointer loads below index `b`
    // with offsets derived from `a.len()`, so a mismatch in a release build
    // would read out of bounds.
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    let mut sum = vdupq_n_s32(0);

    let chunks = len / 8;
    let mut idx = 0usize;

    for _ in 0..chunks {
        let va = vld1_s8(a_ptr.add(idx));
        let vb = vld1_s8(b_ptr.add(idx));

        // Sign-extend i8 -> i16 first so the difference cannot overflow.
        let va_i16 = vmovl_s8(va);
        let vb_i16 = vmovl_s8(vb);

        let diff = vsubq_s16(va_i16, vb_i16);

        let prod_lo = vmull_s16(vget_low_s16(diff), vget_low_s16(diff));
        let prod_hi = vmull_s16(vget_high_s16(diff), vget_high_s16(diff));

        sum = vaddq_s32(sum, prod_lo);
        sum = vaddq_s32(sum, prod_hi);

        idx += 8;
    }

    // Horizontal add across the four i32 lanes.
    let mut total = vaddvq_s32(sum);

    // Trailing fewer-than-8 elements, using safe slicing.
    for (&x, &y) in a[idx..].iter().zip(&b[idx..]) {
        let diff = (x as i32) - (y as i32);
        total += diff * diff;
    }

    total
}
1210
/// AVX2 kernel for the `i8` dot product, accumulating in `i32`.
///
/// Each step consumes 32 bytes per input as two 16-byte halves. Every half is
/// sign-extended to sixteen `i16` lanes and fed to `_mm256_madd_epi16`, which
/// multiplies lane-wise and adds adjacent pairs into eight `i32` lanes.
/// Trailing elements are handled in scalar code.
///
/// # Safety
///
/// The caller must ensure the CPU supports `avx2`.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn dot_product_i8_avx2_impl(a: &[i8], b: &[i8]) -> i32 {
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    const BLOCK: usize = 32;

    let blocks_a = a.chunks_exact(BLOCK);
    let blocks_b = b.chunks_exact(BLOCK);
    let (tail_a, tail_b) = (blocks_a.remainder(), blocks_b.remainder());

    let mut acc = _mm256_setzero_si256();
    for (block_a, block_b) in blocks_a.zip(blocks_b) {
        // View each 32-byte block as two unaligned 128-bit halves.
        let half_a = block_a.as_ptr() as *const __m128i;
        let half_b = block_b.as_ptr() as *const __m128i;

        let low_products = _mm256_madd_epi16(
            _mm256_cvtepi8_epi16(_mm_loadu_si128(half_a)),
            _mm256_cvtepi8_epi16(_mm_loadu_si128(half_b)),
        );
        let high_products = _mm256_madd_epi16(
            _mm256_cvtepi8_epi16(_mm_loadu_si128(half_a.add(1))),
            _mm256_cvtepi8_epi16(_mm_loadu_si128(half_b.add(1))),
        );

        acc = _mm256_add_epi32(acc, low_products);
        acc = _mm256_add_epi32(acc, high_products);
    }

    // Horizontal reduction: spill the register and add the eight lanes.
    let mut lanes = [0i32; 8];
    _mm256_storeu_si256(lanes.as_mut_ptr() as *mut __m256i, acc);
    let mut total: i32 = lanes.iter().sum();

    for (&x, &y) in tail_a.iter().zip(tail_b) {
        total += (x as i32) * (y as i32);
    }

    total
}
1251
/// AVX2 kernel for the squared Euclidean distance between `i8` vectors,
/// accumulating in `i32`.
///
/// Each step consumes 32 bytes per input as two 16-byte halves. Every half is
/// sign-extended to sixteen `i16` lanes, subtracted, and squared with
/// `_mm256_madd_epi16`, which also adds adjacent pairs into eight `i32` lanes.
/// Trailing elements are handled in scalar code.
///
/// # Safety
///
/// The caller must ensure the CPU supports `avx2`.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn euclidean_distance_squared_i8_avx2_impl(a: &[i8], b: &[i8]) -> i32 {
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    const BLOCK: usize = 32;

    let blocks_a = a.chunks_exact(BLOCK);
    let blocks_b = b.chunks_exact(BLOCK);
    let (tail_a, tail_b) = (blocks_a.remainder(), blocks_b.remainder());

    let mut acc = _mm256_setzero_si256();
    for (block_a, block_b) in blocks_a.zip(blocks_b) {
        // View each 32-byte block as two unaligned 128-bit halves.
        let half_a = block_a.as_ptr() as *const __m128i;
        let half_b = block_b.as_ptr() as *const __m128i;

        // Widen before subtracting so the difference fits in i16.
        let low_delta = _mm256_sub_epi16(
            _mm256_cvtepi8_epi16(_mm_loadu_si128(half_a)),
            _mm256_cvtepi8_epi16(_mm_loadu_si128(half_b)),
        );
        let high_delta = _mm256_sub_epi16(
            _mm256_cvtepi8_epi16(_mm_loadu_si128(half_a.add(1))),
            _mm256_cvtepi8_epi16(_mm_loadu_si128(half_b.add(1))),
        );

        acc = _mm256_add_epi32(acc, _mm256_madd_epi16(low_delta, low_delta));
        acc = _mm256_add_epi32(acc, _mm256_madd_epi16(high_delta, high_delta));
    }

    // Horizontal reduction: spill the register and add the eight lanes.
    let mut lanes = [0i32; 8];
    _mm256_storeu_si256(lanes.as_mut_ptr() as *mut __m256i, acc);
    let mut total: i32 = lanes.iter().sum();

    for (&x, &y) in tail_a.iter().zip(tail_b) {
        let delta = (x as i32) - (y as i32);
        total += delta * delta;
    }

    total
}
1293
/// Portable reference implementation of the `i8` dot product.
///
/// Each product is computed in `i32`. Elements are paired with `zip`, so when
/// the slices differ in length only the common prefix contributes.
#[allow(dead_code)]
fn dot_product_i8_scalar(a: &[i8], b: &[i8]) -> i32 {
    let widened_products = a
        .iter()
        .zip(b)
        .map(|(&lhs, &rhs)| i32::from(lhs) * i32::from(rhs));
    widened_products.sum::<i32>()
}
1302
/// Portable reference implementation of the squared Euclidean distance
/// between `i8` vectors.
///
/// Differences and squares are computed in `i32`. Elements are paired with
/// `zip`, so when the slices differ in length only the common prefix
/// contributes.
#[allow(dead_code)]
fn euclidean_distance_squared_i8_scalar(a: &[i8], b: &[i8]) -> i32 {
    let squared_deltas = a.iter().zip(b).map(|(&lhs, &rhs)| {
        let delta = i32::from(lhs) - i32::from(rhs);
        delta * delta
    });
    squared_deltas.sum::<i32>()
}
1314
1315#[inline]
1323pub fn batch_dot_product(query: &[f32], vectors: &[&[f32]], results: &mut [f32]) {
1324 assert_eq!(
1325 vectors.len(),
1326 results.len(),
1327 "Output size must match vector count"
1328 );
1329
1330 const TILE_SIZE: usize = 16;
1332
1333 for (chunk_idx, chunk) in vectors.chunks(TILE_SIZE).enumerate() {
1334 let base_idx = chunk_idx * TILE_SIZE;
1335 for (i, vec) in chunk.iter().enumerate() {
1336 results[base_idx + i] = dot_product_simd(query, vec);
1337 }
1338 }
1339}
1340
1341#[inline]
1345pub fn batch_euclidean(query: &[f32], vectors: &[&[f32]], results: &mut [f32]) {
1346 assert_eq!(
1347 vectors.len(),
1348 results.len(),
1349 "Output size must match vector count"
1350 );
1351
1352 const TILE_SIZE: usize = 16;
1353
1354 for (chunk_idx, chunk) in vectors.chunks(TILE_SIZE).enumerate() {
1355 let base_idx = chunk_idx * TILE_SIZE;
1356 for (i, vec) in chunk.iter().enumerate() {
1357 results[base_idx + i] = euclidean_distance_simd(query, vec);
1358 }
1359 }
1360}
1361
1362#[inline]
1364pub fn batch_cosine_similarity(query: &[f32], vectors: &[&[f32]], results: &mut [f32]) {
1365 assert_eq!(
1366 vectors.len(),
1367 results.len(),
1368 "Output size must match vector count"
1369 );
1370
1371 const TILE_SIZE: usize = 16;
1372
1373 for (chunk_idx, chunk) in vectors.chunks(TILE_SIZE).enumerate() {
1374 let base_idx = chunk_idx * TILE_SIZE;
1375 for (i, vec) in chunk.iter().enumerate() {
1376 results[base_idx + i] = cosine_similarity_simd(query, vec);
1377 }
1378 }
1379}
1380
1381#[inline]
1383pub fn batch_dot_product_owned(query: &[f32], vectors: &[Vec<f32>]) -> Vec<f32> {
1384 let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();
1385 let mut results = vec![0.0; vectors.len()];
1386 batch_dot_product(query, &refs, &mut results);
1387 results
1388}
1389
1390#[inline]
1392pub fn batch_euclidean_owned(query: &[f32], vectors: &[Vec<f32>]) -> Vec<f32> {
1393 let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();
1394 let mut results = vec![0.0; vectors.len()];
1395 batch_euclidean(query, &refs, &mut results);
1396 results
1397}
1398
#[cfg(test)]
mod tests {
    //! Unit tests for the SIMD distance kernels and the batch helpers.
    //!
    //! Most tests compare a dispatching `*_simd` / `*_i8` entry point against
    //! its scalar reference implementation on the same input.

    use super::*;

    /// Eight-element input: the dispatcher must agree with the scalar
    /// reference.
    #[test]
    fn test_euclidean_distance_simd() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let b = vec![2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];

        let result = euclidean_distance_simd(&a, &b);
        let expected = euclidean_distance_scalar(&a, &b);

        assert!(
            (result - expected).abs() < 0.001,
            "SIMD result {} differs from scalar result {}",
            result,
            expected
        );
    }

    /// 128-element input with a constant 0.5 offset between the vectors.
    #[test]
    fn test_euclidean_distance_large() {
        let a: Vec<f32> = (0..128).map(|i| i as f32 * 0.1).collect();
        let b: Vec<f32> = (0..128).map(|i| (i as f32 * 0.1) + 0.5).collect();

        let result = euclidean_distance_simd(&a, &b);
        let expected = euclidean_distance_scalar(&a, &b);

        assert!(
            (result - expected).abs() < 0.01,
            "Large vector: SIMD {} vs scalar {}",
            result,
            expected
        );
    }

    /// Sixteen pairs of `1.0 * 2.0` must sum to 32.
    #[test]
    fn test_dot_product_simd() {
        let a = vec![1.0; 16];
        let b = vec![2.0; 16];

        let result = dot_product_simd(&a, &b);
        assert!((result - 32.0).abs() < 0.001);
    }

    /// 256-element input of small integers, checked against the scalar
    /// reference.
    #[test]
    fn test_dot_product_large() {
        let a: Vec<f32> = (0..256).map(|i| (i % 10) as f32).collect();
        let b: Vec<f32> = (0..256).map(|i| ((i + 5) % 10) as f32).collect();

        let result = dot_product_simd(&a, &b);
        let expected = dot_product_scalar(&a, &b);

        assert!(
            (result - expected).abs() < 0.1,
            "Large dot product: SIMD {} vs scalar {}",
            result,
            expected
        );
    }

    /// Identical vectors have a cosine similarity of 1.
    #[test]
    fn test_cosine_similarity_simd() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];

        let result = cosine_similarity_simd(&a, &b);
        assert!((result - 1.0).abs() < 0.001);
    }

    /// Orthogonal unit vectors have a cosine similarity of 0.
    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = vec![1.0, 0.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0, 0.0];

        let result = cosine_similarity_simd(&a, &b);
        assert!(
            result.abs() < 0.001,
            "Orthogonal vectors should have ~0 similarity, got {}",
            result
        );
    }

    /// Manhattan distance must match the scalar reference and the
    /// hand-computed value: four components each differing by 4 give 16.
    #[test]
    fn test_manhattan_distance_simd() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];

        let result = manhattan_distance_simd(&a, &b);
        let expected = manhattan_distance_scalar(&a, &b);

        assert!(
            (result - expected).abs() < 0.001,
            "Manhattan: SIMD {} vs scalar {}",
            result,
            expected
        );
        assert!((result - 16.0).abs() < 0.001);
    }

    /// Seven elements: shorter than one 8-lane chunk, so the remainder
    /// handling is what produces the result.
    #[test]
    fn test_non_aligned_lengths() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];
        let b = vec![2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];

        let result = euclidean_distance_simd(&a, &b);
        let expected = euclidean_distance_scalar(&a, &b);

        assert!(
            (result - expected).abs() < 0.001,
            "Non-aligned: SIMD {} vs scalar {}",
            result,
            expected
        );
    }

    /// Smoke test: the legacy `*_avx2` entry points can still be called.
    /// The returned values are deliberately discarded.
    #[test]
    fn test_legacy_avx2_aliases() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];

        let _ = euclidean_distance_avx2(&a, &b);
        let _ = dot_product_avx2(&a, &b);
        let _ = cosine_similarity_avx2(&a, &b);
    }

    /// 16-element INT8 dot product must equal the scalar reference exactly.
    #[test]
    fn test_dot_product_i8() {
        let a: Vec<i8> = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
        let b: Vec<i8> = vec![2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17];

        let result = dot_product_i8(&a, &b);
        let expected = dot_product_i8_scalar(&a, &b);

        assert_eq!(
            result, expected,
            "INT8 dot product: SIMD {} vs scalar {}",
            result, expected
        );
    }

    /// 128-element INT8 dot product over inputs that mix negative and
    /// positive values (`a` spans -64..=63).
    #[test]
    fn test_dot_product_i8_large() {
        let a: Vec<i8> = (0..128)
            .map(|i| ((i % 256) as i8).wrapping_sub(64))
            .collect();
        let b: Vec<i8> = (0..128)
            .map(|i| (((i + 10) % 256) as i8).wrapping_sub(64))
            .collect();

        let result = dot_product_i8(&a, &b);
        let expected = dot_product_i8_scalar(&a, &b);

        assert_eq!(
            result, expected,
            "Large INT8 dot product: SIMD {} vs scalar {}",
            result, expected
        );
    }

    /// Every component differs by exactly 1, so the squared distance over
    /// 16 elements is 16.
    #[test]
    fn test_euclidean_distance_squared_i8() {
        let a: Vec<i8> = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
        let b: Vec<i8> = vec![2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17];

        let result = euclidean_distance_squared_i8(&a, &b);
        let expected = euclidean_distance_squared_i8_scalar(&a, &b);

        assert_eq!(
            result, expected,
            "INT8 euclidean^2: SIMD {} vs scalar {}",
            result, expected
        );
        assert_eq!(result, 16, "Expected 16, got {}", result);
    }

    /// 128-element INT8 squared distance, checked against the scalar
    /// reference.
    #[test]
    fn test_euclidean_distance_squared_i8_large() {
        let a: Vec<i8> = (0..128)
            .map(|i| ((i % 256) as i8).wrapping_sub(64))
            .collect();
        let b: Vec<i8> = (0..128)
            .map(|i| (((i + 5) % 256) as i8).wrapping_sub(64))
            .collect();

        let result = euclidean_distance_squared_i8(&a, &b);
        let expected = euclidean_distance_squared_i8_scalar(&a, &b);

        assert_eq!(
            result, expected,
            "Large INT8 euclidean^2: SIMD {} vs scalar {}",
            result, expected
        );
    }

    /// Unit basis vectors pick out the matching component of the query.
    #[test]
    fn test_batch_dot_product() {
        let query = vec![1.0, 2.0, 3.0, 4.0];
        let v1 = vec![1.0, 0.0, 0.0, 0.0];
        let v2 = vec![0.0, 1.0, 0.0, 0.0];
        let v3 = vec![0.0, 0.0, 1.0, 0.0];
        let vectors: Vec<&[f32]> = vec![&v1, &v2, &v3];
        let mut results = vec![0.0; 3];

        batch_dot_product(&query, &vectors, &mut results);

        assert!((results[0] - 1.0).abs() < 0.001);
        assert!((results[1] - 2.0).abs() < 0.001);
        assert!((results[2] - 3.0).abs() < 0.001);
    }

    /// More vectors than one tile of 16, with a partial final tile: every
    /// result must land in the slot that matches its input vector.
    #[test]
    fn test_batch_dot_product_crosses_tile_boundary() {
        let query = vec![1.0, 2.0, 3.0, 4.0];
        // Vector `i` is filled with `i`, so each vector yields a distinct
        // product and a misplaced result cannot go unnoticed.
        let owned: Vec<Vec<f32>> = (0..40).map(|i| vec![i as f32; 4]).collect();
        let vectors: Vec<&[f32]> = owned.iter().map(|v| v.as_slice()).collect();
        let mut results = vec![0.0; vectors.len()];

        batch_dot_product(&query, &vectors, &mut results);

        for (i, vec) in vectors.iter().enumerate() {
            let expected = dot_product_simd(&query, vec);
            assert!(
                (results[i] - expected).abs() < 0.001,
                "Slot {}: batch {} vs direct {}",
                i,
                results[i],
                expected
            );
        }
    }

    /// 3-4-5 and 5-12-13 right triangles measured from the origin.
    #[test]
    fn test_batch_euclidean() {
        let query = vec![0.0, 0.0, 0.0, 0.0];
        let v1 = vec![3.0, 4.0, 0.0, 0.0];
        let v2 = vec![0.0, 0.0, 5.0, 12.0];
        let vectors: Vec<&[f32]> = vec![&v1, &v2];
        let mut results = vec![0.0; 2];

        batch_euclidean(&query, &vectors, &mut results);

        assert!(
            (results[0] - 5.0).abs() < 0.001,
            "Expected 5.0, got {}",
            results[0]
        );
        assert!(
            (results[1] - 13.0).abs() < 0.001,
            "Expected 13.0, got {}",
            results[1]
        );
    }

    /// Same direction, orthogonal, and opposite direction relative to the
    /// query.
    #[test]
    fn test_batch_cosine_similarity() {
        let query = vec![1.0, 0.0, 0.0, 0.0];
        let v1 = vec![1.0, 0.0, 0.0, 0.0];
        let v2 = vec![0.0, 1.0, 0.0, 0.0];
        let v3 = vec![-1.0, 0.0, 0.0, 0.0];
        let vectors: Vec<&[f32]> = vec![&v1, &v2, &v3];
        let mut results = vec![0.0; 3];

        batch_cosine_similarity(&query, &vectors, &mut results);

        assert!(
            (results[0] - 1.0).abs() < 0.001,
            "Same direction should be 1.0"
        );
        assert!(results[1].abs() < 0.001, "Orthogonal should be 0.0");
        assert!((results[2] + 1.0).abs() < 0.001, "Opposite should be -1.0");
    }

    /// The owned-vector dot-product wrapper returns one result per input
    /// vector, in order.
    #[test]
    fn test_batch_owned_convenience() {
        let query = vec![1.0, 2.0, 3.0, 4.0];
        let vectors = vec![vec![1.0, 0.0, 0.0, 0.0], vec![0.0, 1.0, 0.0, 0.0]];

        let results = batch_dot_product_owned(&query, &vectors);
        assert_eq!(results.len(), 2);
        assert!((results[0] - 1.0).abs() < 0.001);
        assert!((results[1] - 2.0).abs() < 0.001);
    }

    /// The owned-vector Euclidean wrapper returns one distance per input
    /// vector, in order (same triangles as `test_batch_euclidean`).
    #[test]
    fn test_batch_euclidean_owned() {
        let query = vec![0.0, 0.0, 0.0, 0.0];
        let vectors = vec![vec![3.0, 4.0, 0.0, 0.0], vec![0.0, 0.0, 5.0, 12.0]];

        let results = batch_euclidean_owned(&query, &vectors);
        assert_eq!(results.len(), 2);
        assert!(
            (results[0] - 5.0).abs() < 0.001,
            "Expected 5.0, got {}",
            results[0]
        );
        assert!(
            (results[1] - 13.0).abs() < 0.001,
            "Expected 13.0, got {}",
            results[1]
        );
    }

    /// 128-element input, which is at or above the 64-element threshold at
    /// which the aarch64 dispatcher selects the unrolled NEON kernel.
    ///
    /// NOTE(review): the inputs and assertion are currently identical to
    /// `test_euclidean_distance_large`; only the long-input path is
    /// exercised here.
    #[test]
    fn test_unrolled_vs_non_unrolled_consistency() {
        let a: Vec<f32> = (0..128).map(|i| i as f32 * 0.1).collect();
        let b: Vec<f32> = (0..128).map(|i| (i as f32 * 0.1) + 0.5).collect();

        let result = euclidean_distance_simd(&a, &b);
        let expected = euclidean_distance_scalar(&a, &b);

        assert!(
            (result - expected).abs() < 0.01,
            "Unrolled consistency: SIMD {} vs scalar {}",
            result,
            expected
        );
    }
}