1#[cfg(target_arch = "x86_64")]
25use std::arch::x86_64::*;
26
27#[cfg(target_arch = "aarch64")]
28use std::arch::aarch64::*;
29
/// Prefetch look-ahead distance.
///
/// None of the SIMD kernels here read this constant, and nothing in this code
/// fixes its unit (elements vs. bytes) — confirm before wiring it in.
#[allow(dead_code)] // Unreferenced by the kernels; the allowance silences the unused warning.
const PREFETCH_DISTANCE: usize = 64;
33
34#[inline(always)]
42pub fn euclidean_distance_simd(a: &[f32], b: &[f32]) -> f32 {
43 #[cfg(target_arch = "x86_64")]
44 {
45 if is_x86_feature_detected!("avx512f") {
46 unsafe { euclidean_distance_avx512_impl(a, b) }
47 } else if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
48 unsafe { euclidean_distance_avx2_fma_impl(a, b) }
49 } else if is_x86_feature_detected!("avx2") {
50 unsafe { euclidean_distance_avx2_impl(a, b) }
51 } else {
52 euclidean_distance_scalar(a, b)
53 }
54 }
55
56 #[cfg(target_arch = "aarch64")]
57 {
58 if a.len() >= 64 {
60 unsafe { euclidean_distance_neon_unrolled_impl(a, b) }
61 } else {
62 unsafe { euclidean_distance_neon_impl(a, b) }
63 }
64 }
65
66 #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
67 {
68 euclidean_distance_scalar(a, b)
69 }
70}
71
72#[inline(always)]
74pub fn euclidean_distance_avx2(a: &[f32], b: &[f32]) -> f32 {
75 euclidean_distance_simd(a, b)
76}
77
/// AVX2 Euclidean (L2) distance using separate multiply and add (no FMA).
///
/// Squared differences are accumulated 8 lanes at a time. The 8 lanes are
/// then summed left to right, the last `len % 8` elements are added in scalar
/// code, and the square root of the total is returned.
///
/// # Panics
///
/// Panics if `a.len() != b.len()`.
///
/// # Safety
///
/// The running CPU must support AVX2 (e.g. check
/// `is_x86_feature_detected!("avx2")` before calling).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn euclidean_distance_avx2_impl(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let lhs = a.chunks_exact(8);
    let rhs = b.chunks_exact(8);
    let (lhs_tail, rhs_tail) = (lhs.remainder(), rhs.remainder());

    let mut acc = _mm256_setzero_ps();
    for (x, y) in lhs.zip(rhs) {
        let delta = _mm256_sub_ps(_mm256_loadu_ps(x.as_ptr()), _mm256_loadu_ps(y.as_ptr()));
        acc = _mm256_add_ps(acc, _mm256_mul_ps(delta, delta));
    }

    // Spill the vector and reduce its lanes in index order.
    let mut lanes = [0.0f32; 8];
    _mm256_storeu_ps(lanes.as_mut_ptr(), acc);
    let mut total = lanes.iter().sum::<f32>();

    for (&x, &y) in lhs_tail.iter().zip(rhs_tail) {
        let delta = x - y;
        total += delta * delta;
    }

    total.sqrt()
}
118
/// AVX2 + FMA Euclidean (L2) distance with four independent accumulators.
///
/// Stages:
/// 1. 32-element blocks: each 8-lane slice of a block feeds its own
///    accumulator via fused multiply-add.
/// 2. The accumulators are folded as `(acc0 + acc1) + (acc2 + acc3)`.
/// 3. Leftover whole 8-element groups are FMA'd into the folded vector.
/// 4. The 8 lanes are summed left to right, the final `len % 8` elements are
///    added in scalar code, and the square root is returned.
///
/// # Panics
///
/// Panics if `a.len() != b.len()`.
///
/// # Safety
///
/// The running CPU must support both AVX2 and FMA.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn euclidean_distance_avx2_fma_impl(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let wide_a = a.chunks_exact(32);
    let wide_b = b.chunks_exact(32);
    let (rest_a, rest_b) = (wide_a.remainder(), wide_b.remainder());

    // Stage 1: accumulator `k` always receives elements [8k, 8k + 8) of a block.
    let mut acc = [_mm256_setzero_ps(); 4];
    for (blk_a, blk_b) in wide_a.zip(wide_b) {
        for (k, slot) in acc.iter_mut().enumerate() {
            let off = k * 8;
            let delta = _mm256_sub_ps(
                _mm256_loadu_ps(blk_a.as_ptr().add(off)),
                _mm256_loadu_ps(blk_b.as_ptr().add(off)),
            );
            *slot = _mm256_fmadd_ps(delta, delta, *slot);
        }
    }

    // Stage 2: pairwise fold, in the same association as the reduction tree.
    let mut folded = _mm256_add_ps(
        _mm256_add_ps(acc[0], acc[1]),
        _mm256_add_ps(acc[2], acc[3]),
    );

    // Stage 3: at most three whole 8-lane groups remain after the 32-wide loop.
    let mid_a = rest_a.chunks_exact(8);
    let mid_b = rest_b.chunks_exact(8);
    let (tail_a, tail_b) = (mid_a.remainder(), mid_b.remainder());
    for (x, y) in mid_a.zip(mid_b) {
        let delta = _mm256_sub_ps(_mm256_loadu_ps(x.as_ptr()), _mm256_loadu_ps(y.as_ptr()));
        folded = _mm256_fmadd_ps(delta, delta, folded);
    }

    // Stage 4: horizontal sum, then the scalar tail.
    let mut lanes = [0.0f32; 8];
    _mm256_storeu_ps(lanes.as_mut_ptr(), folded);
    let mut total = lanes.iter().sum::<f32>();

    for (&x, &y) in tail_a.iter().zip(tail_b) {
        let delta = x - y;
        total += delta * delta;
    }

    total.sqrt()
}
189
190#[cfg(target_arch = "x86_64")]
196#[target_feature(enable = "avx512f")]
197unsafe fn euclidean_distance_avx512_impl(a: &[f32], b: &[f32]) -> f32 {
198 assert_eq!(a.len(), b.len(), "Input arrays must have the same length");
199
200 let len = a.len();
201 let mut sum = _mm512_setzero_ps();
202
203 let chunks = len / 16;
205 for i in 0..chunks {
206 let idx = i * 16;
207 let va = _mm512_loadu_ps(a.as_ptr().add(idx));
208 let vb = _mm512_loadu_ps(b.as_ptr().add(idx));
209 let diff = _mm512_sub_ps(va, vb);
210 sum = _mm512_fmadd_ps(diff, diff, sum);
211 }
212
213 let mut total = _mm512_reduce_add_ps(sum);
215
216 for i in (chunks * 16)..len {
218 let diff = a[i] - b[i];
219 total += diff * diff;
220 }
221
222 total.sqrt()
223}
224
225#[cfg(target_arch = "x86_64")]
227#[target_feature(enable = "avx512f")]
228unsafe fn dot_product_avx512_impl(a: &[f32], b: &[f32]) -> f32 {
229 assert_eq!(a.len(), b.len(), "Input arrays must have the same length");
230
231 let len = a.len();
232 let mut sum = _mm512_setzero_ps();
233
234 let chunks = len / 16;
235 for i in 0..chunks {
236 let idx = i * 16;
237 let va = _mm512_loadu_ps(a.as_ptr().add(idx));
238 let vb = _mm512_loadu_ps(b.as_ptr().add(idx));
239 sum = _mm512_fmadd_ps(va, vb, sum);
240 }
241
242 let mut total = _mm512_reduce_add_ps(sum);
243
244 for i in (chunks * 16)..len {
245 total += a[i] * b[i];
246 }
247
248 total
249}
250
251#[cfg(target_arch = "x86_64")]
253#[target_feature(enable = "avx512f")]
254unsafe fn cosine_similarity_avx512_impl(a: &[f32], b: &[f32]) -> f32 {
255 assert_eq!(a.len(), b.len(), "Input arrays must have the same length");
256
257 let len = a.len();
258 let mut dot = _mm512_setzero_ps();
259 let mut norm_a = _mm512_setzero_ps();
260 let mut norm_b = _mm512_setzero_ps();
261
262 let chunks = len / 16;
263 for i in 0..chunks {
264 let idx = i * 16;
265 let va = _mm512_loadu_ps(a.as_ptr().add(idx));
266 let vb = _mm512_loadu_ps(b.as_ptr().add(idx));
267
268 dot = _mm512_fmadd_ps(va, vb, dot);
269 norm_a = _mm512_fmadd_ps(va, va, norm_a);
270 norm_b = _mm512_fmadd_ps(vb, vb, norm_b);
271 }
272
273 let mut dot_sum = _mm512_reduce_add_ps(dot);
274 let mut norm_a_sum = _mm512_reduce_add_ps(norm_a);
275 let mut norm_b_sum = _mm512_reduce_add_ps(norm_b);
276
277 for i in (chunks * 16)..len {
278 dot_sum += a[i] * b[i];
279 norm_a_sum += a[i] * a[i];
280 norm_b_sum += b[i] * b[i];
281 }
282
283 dot_sum / (norm_a_sum.sqrt() * norm_b_sum.sqrt())
284}
285
286#[cfg(target_arch = "x86_64")]
288#[target_feature(enable = "avx512f")]
289unsafe fn manhattan_distance_avx512_impl(a: &[f32], b: &[f32]) -> f32 {
290 assert_eq!(a.len(), b.len(), "Input arrays must have the same length");
291
292 let len = a.len();
293 let mut sum = _mm512_setzero_ps();
294
295 let chunks = len / 16;
296 for i in 0..chunks {
297 let idx = i * 16;
298 let va = _mm512_loadu_ps(a.as_ptr().add(idx));
299 let vb = _mm512_loadu_ps(b.as_ptr().add(idx));
300 let diff = _mm512_sub_ps(va, vb);
301 let abs_diff = _mm512_abs_ps(diff);
302 sum = _mm512_add_ps(sum, abs_diff);
303 }
304
305 let mut total = _mm512_reduce_add_ps(sum);
306
307 for i in (chunks * 16)..len {
308 total += (a[i] - b[i]).abs();
309 }
310
311 total
312}
313
/// NEON Euclidean (L2) distance, 4 lanes per step.
///
/// Squared differences are accumulated in one `float32x4_t` with fused
/// multiply-add and reduced with `vaddvq_f32`. The last `len % 4` elements
/// are added in scalar code before the square root.
///
/// # Panics
///
/// Panics if `a.len() != b.len()`.
///
/// # Safety
///
/// Declared without `#[target_feature]`, so NEON must be enabled for the
/// aarch64 target being built.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn euclidean_distance_neon_impl(a: &[f32], b: &[f32]) -> f32 {
    // Hard assert, not `debug_assert`: the loads and `get_unchecked` below
    // index `b` with offsets bounded only by `a.len()`. A debug-only check
    // would let release builds read past the end of a shorter `b`.
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let mut sum = vdupq_n_f32(0.0);

    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    let chunks = len / 4;
    let mut idx = 0usize;

    for _ in 0..chunks {
        // SAFETY: idx + 4 <= chunks * 4 <= len == a.len() == b.len().
        let va = vld1q_f32(a_ptr.add(idx));
        let vb = vld1q_f32(b_ptr.add(idx));

        let diff = vsubq_f32(va, vb);
        sum = vfmaq_f32(sum, diff, diff);

        idx += 4;
    }

    let mut total = vaddvq_f32(sum);

    for i in (chunks * 4)..len {
        // SAFETY: i < len == a.len() == b.len().
        let diff = *a.get_unchecked(i) - *b.get_unchecked(i);
        total += diff * diff;
    }

    total.sqrt()
}
363
/// NEON dot product, 4 lanes per step with fused multiply-add, followed by a
/// `vaddvq_f32` reduction and a scalar pass over the last `len % 4` elements.
///
/// # Panics
///
/// Panics if `a.len() != b.len()`.
///
/// # Safety
///
/// Declared without `#[target_feature]`, so NEON must be enabled for the
/// aarch64 target being built.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn dot_product_neon_impl(a: &[f32], b: &[f32]) -> f32 {
    // Hard assert: the loads and `get_unchecked` below index `b` with offsets
    // bounded only by `a.len()`, so a debug-only check would leave release
    // builds open to out-of-bounds reads.
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let mut sum = vdupq_n_f32(0.0);

    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    let chunks = len / 4;
    let mut idx = 0usize;

    for _ in 0..chunks {
        // SAFETY: idx + 4 <= chunks * 4 <= len == a.len() == b.len().
        let va = vld1q_f32(a_ptr.add(idx));
        let vb = vld1q_f32(b_ptr.add(idx));

        sum = vfmaq_f32(sum, va, vb);

        idx += 4;
    }

    let mut total = vaddvq_f32(sum);

    for i in (chunks * 4)..len {
        // SAFETY: i < len == a.len() == b.len().
        total += *a.get_unchecked(i) * *b.get_unchecked(i);
    }

    total
}
401
/// NEON cosine similarity, 4 lanes per step.
///
/// One pass accumulates `a·b`, `a·a` and `b·b` with fused multiply-add. Each
/// is reduced with `vaddvq_f32`, the last `len % 4` elements are added in
/// scalar code, and the result is `dot / (|a| * |b|)`. That quotient is not
/// finite when a norm is zero.
///
/// # Panics
///
/// Panics if `a.len() != b.len()`.
///
/// # Safety
///
/// Declared without `#[target_feature]`, so NEON must be enabled for the
/// aarch64 target being built.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn cosine_similarity_neon_impl(a: &[f32], b: &[f32]) -> f32 {
    // Hard assert: `b` is read through raw pointers and `get_unchecked` at
    // offsets bounded only by `a.len()`.
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let mut dot = vdupq_n_f32(0.0);
    let mut norm_a = vdupq_n_f32(0.0);
    let mut norm_b = vdupq_n_f32(0.0);

    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    let chunks = len / 4;
    let mut idx = 0usize;

    for _ in 0..chunks {
        // SAFETY: idx + 4 <= chunks * 4 <= len == a.len() == b.len().
        let va = vld1q_f32(a_ptr.add(idx));
        let vb = vld1q_f32(b_ptr.add(idx));

        dot = vfmaq_f32(dot, va, vb);
        norm_a = vfmaq_f32(norm_a, va, va);
        norm_b = vfmaq_f32(norm_b, vb, vb);

        idx += 4;
    }

    let mut dot_sum = vaddvq_f32(dot);
    let mut norm_a_sum = vaddvq_f32(norm_a);
    let mut norm_b_sum = vaddvq_f32(norm_b);

    for i in (chunks * 4)..len {
        // SAFETY: i < len == a.len() == b.len().
        let ai = *a.get_unchecked(i);
        let bi = *b.get_unchecked(i);
        dot_sum += ai * bi;
        norm_a_sum += ai * ai;
        norm_b_sum += bi * bi;
    }

    dot_sum / (norm_a_sum.sqrt() * norm_b_sum.sqrt())
}
451
/// NEON Manhattan (L1) distance, 4 lanes per step.
///
/// `vabdq_f32` produces `|a - b|` per lane. The lanes are summed, reduced
/// with `vaddvq_f32`, and the last `len % 4` elements are added in scalar
/// code.
///
/// # Panics
///
/// Panics if `a.len() != b.len()`.
///
/// # Safety
///
/// Declared without `#[target_feature]`, so NEON must be enabled for the
/// aarch64 target being built.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn manhattan_distance_neon_impl(a: &[f32], b: &[f32]) -> f32 {
    // Hard assert: `b` is read through raw pointers and `get_unchecked` at
    // offsets bounded only by `a.len()`.
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let mut sum = vdupq_n_f32(0.0);

    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    let chunks = len / 4;
    let mut idx = 0usize;

    for _ in 0..chunks {
        // SAFETY: idx + 4 <= chunks * 4 <= len == a.len() == b.len().
        let va = vld1q_f32(a_ptr.add(idx));
        let vb = vld1q_f32(b_ptr.add(idx));

        let abs_diff = vabdq_f32(va, vb);
        sum = vaddq_f32(sum, abs_diff);

        idx += 4;
    }

    let mut total = vaddvq_f32(sum);

    for i in (chunks * 4)..len {
        // SAFETY: i < len == a.len() == b.len().
        total += (*a.get_unchecked(i) - *b.get_unchecked(i)).abs();
    }

    total
}
490
/// Unrolled NEON Euclidean (L2) distance.
///
/// Stages:
/// 1. 16 elements per iteration into four independent `float32x4_t`
///    accumulators (presumably to shorten the FMA dependency chain).
/// 2. The four are folded as `(s0 + s1) + (s2 + s3)`.
/// 3. Remaining whole groups of 4 are added with a single accumulator.
/// 4. The last `len % 4` elements are handled in scalar code, then the
///    square root is taken.
///
/// # Panics
///
/// Panics if `a.len() != b.len()`.
///
/// # Safety
///
/// Declared without `#[target_feature]`, so NEON must be enabled for the
/// aarch64 target being built.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn euclidean_distance_neon_unrolled_impl(a: &[f32], b: &[f32]) -> f32 {
    // Hard assert: every load below reads `b` at offsets bounded only by
    // `a.len()`. A debug-only check would allow out-of-bounds reads in release.
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    let mut sum0 = vdupq_n_f32(0.0);
    let mut sum1 = vdupq_n_f32(0.0);
    let mut sum2 = vdupq_n_f32(0.0);
    let mut sum3 = vdupq_n_f32(0.0);

    let chunks = len / 16;
    let mut idx = 0usize;

    for _ in 0..chunks {
        // SAFETY: idx + 16 <= chunks * 16 <= len == a.len() == b.len().
        let va0 = vld1q_f32(a_ptr.add(idx));
        let vb0 = vld1q_f32(b_ptr.add(idx));
        let diff0 = vsubq_f32(va0, vb0);
        sum0 = vfmaq_f32(sum0, diff0, diff0);

        let va1 = vld1q_f32(a_ptr.add(idx + 4));
        let vb1 = vld1q_f32(b_ptr.add(idx + 4));
        let diff1 = vsubq_f32(va1, vb1);
        sum1 = vfmaq_f32(sum1, diff1, diff1);

        let va2 = vld1q_f32(a_ptr.add(idx + 8));
        let vb2 = vld1q_f32(b_ptr.add(idx + 8));
        let diff2 = vsubq_f32(va2, vb2);
        sum2 = vfmaq_f32(sum2, diff2, diff2);

        let va3 = vld1q_f32(a_ptr.add(idx + 12));
        let vb3 = vld1q_f32(b_ptr.add(idx + 12));
        let diff3 = vsubq_f32(va3, vb3);
        sum3 = vfmaq_f32(sum3, diff3, diff3);

        idx += 16;
    }

    let sum01 = vaddq_f32(sum0, sum1);
    let sum23 = vaddq_f32(sum2, sum3);
    let sum = vaddq_f32(sum01, sum23);

    let remaining_start = chunks * 16;
    let remaining_chunks = (len - remaining_start) / 4;
    let mut final_sum = sum;

    idx = remaining_start;
    for _ in 0..remaining_chunks {
        // SAFETY: idx + 4 <= remaining_start + remaining_chunks * 4 <= len.
        let va = vld1q_f32(a_ptr.add(idx));
        let vb = vld1q_f32(b_ptr.add(idx));
        let diff = vsubq_f32(va, vb);
        final_sum = vfmaq_f32(final_sum, diff, diff);
        idx += 4;
    }

    let mut total = vaddvq_f32(final_sum);

    let scalar_start = remaining_start + remaining_chunks * 4;
    for i in scalar_start..len {
        // SAFETY: i < len == a.len() == b.len().
        let diff = *a.get_unchecked(i) - *b.get_unchecked(i);
        total += diff * diff;
    }

    total.sqrt()
}
576
/// Unrolled NEON dot product.
///
/// Stages:
/// 1. 16 elements per iteration into four independent FMA accumulators.
/// 2. The four are folded as `(s0 + s1) + (s2 + s3)`.
/// 3. Remaining whole groups of 4 are added with a single accumulator.
/// 4. The last `len % 4` elements are handled in scalar code.
///
/// # Panics
///
/// Panics if `a.len() != b.len()`.
///
/// # Safety
///
/// Declared without `#[target_feature]`, so NEON must be enabled for the
/// aarch64 target being built.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn dot_product_neon_unrolled_impl(a: &[f32], b: &[f32]) -> f32 {
    // Hard assert: `b` is read at offsets bounded only by `a.len()`.
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    let mut sum0 = vdupq_n_f32(0.0);
    let mut sum1 = vdupq_n_f32(0.0);
    let mut sum2 = vdupq_n_f32(0.0);
    let mut sum3 = vdupq_n_f32(0.0);

    let chunks = len / 16;
    let mut idx = 0usize;

    for _ in 0..chunks {
        // SAFETY: idx + 16 <= chunks * 16 <= len == a.len() == b.len().
        let va0 = vld1q_f32(a_ptr.add(idx));
        let vb0 = vld1q_f32(b_ptr.add(idx));
        sum0 = vfmaq_f32(sum0, va0, vb0);

        let va1 = vld1q_f32(a_ptr.add(idx + 4));
        let vb1 = vld1q_f32(b_ptr.add(idx + 4));
        sum1 = vfmaq_f32(sum1, va1, vb1);

        let va2 = vld1q_f32(a_ptr.add(idx + 8));
        let vb2 = vld1q_f32(b_ptr.add(idx + 8));
        sum2 = vfmaq_f32(sum2, va2, vb2);

        let va3 = vld1q_f32(a_ptr.add(idx + 12));
        let vb3 = vld1q_f32(b_ptr.add(idx + 12));
        sum3 = vfmaq_f32(sum3, va3, vb3);

        idx += 16;
    }

    let sum01 = vaddq_f32(sum0, sum1);
    let sum23 = vaddq_f32(sum2, sum3);
    let sum = vaddq_f32(sum01, sum23);

    let remaining_start = chunks * 16;
    let remaining_chunks = (len - remaining_start) / 4;
    let mut final_sum = sum;

    idx = remaining_start;
    for _ in 0..remaining_chunks {
        // SAFETY: idx + 4 <= remaining_start + remaining_chunks * 4 <= len.
        let va = vld1q_f32(a_ptr.add(idx));
        let vb = vld1q_f32(b_ptr.add(idx));
        final_sum = vfmaq_f32(final_sum, va, vb);
        idx += 4;
    }

    let mut total = vaddvq_f32(final_sum);

    let scalar_start = remaining_start + remaining_chunks * 4;
    for i in scalar_start..len {
        // SAFETY: i < len == a.len() == b.len().
        total += *a.get_unchecked(i) * *b.get_unchecked(i);
    }

    total
}
645
/// Unrolled NEON cosine similarity.
///
/// 8 elements per iteration feed two independent sets of `dot`/`norm_a`/
/// `norm_b` FMA accumulators. The pairs are added, each sum is reduced with
/// `vaddvq_f32`, the last `len % 8` elements are handled in scalar code, and
/// the result is `dot / (|a| * |b|)`. That quotient is not finite when a norm
/// is zero.
///
/// # Panics
///
/// Panics if `a.len() != b.len()`.
///
/// # Safety
///
/// Declared without `#[target_feature]`, so NEON must be enabled for the
/// aarch64 target being built.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn cosine_similarity_neon_unrolled_impl(a: &[f32], b: &[f32]) -> f32 {
    // Hard assert: `b` is read at offsets bounded only by `a.len()`.
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    let mut dot0 = vdupq_n_f32(0.0);
    let mut dot1 = vdupq_n_f32(0.0);
    let mut norm_a0 = vdupq_n_f32(0.0);
    let mut norm_a1 = vdupq_n_f32(0.0);
    let mut norm_b0 = vdupq_n_f32(0.0);
    let mut norm_b1 = vdupq_n_f32(0.0);

    let chunks = len / 8;
    let mut idx = 0usize;

    for _ in 0..chunks {
        // SAFETY: idx + 8 <= chunks * 8 <= len == a.len() == b.len().
        let va0 = vld1q_f32(a_ptr.add(idx));
        let vb0 = vld1q_f32(b_ptr.add(idx));
        dot0 = vfmaq_f32(dot0, va0, vb0);
        norm_a0 = vfmaq_f32(norm_a0, va0, va0);
        norm_b0 = vfmaq_f32(norm_b0, vb0, vb0);

        let va1 = vld1q_f32(a_ptr.add(idx + 4));
        let vb1 = vld1q_f32(b_ptr.add(idx + 4));
        dot1 = vfmaq_f32(dot1, va1, vb1);
        norm_a1 = vfmaq_f32(norm_a1, va1, va1);
        norm_b1 = vfmaq_f32(norm_b1, vb1, vb1);

        idx += 8;
    }

    let dot = vaddq_f32(dot0, dot1);
    let norm_a = vaddq_f32(norm_a0, norm_a1);
    let norm_b = vaddq_f32(norm_b0, norm_b1);

    let mut dot_sum = vaddvq_f32(dot);
    let mut norm_a_sum = vaddvq_f32(norm_a);
    let mut norm_b_sum = vaddvq_f32(norm_b);

    for i in (chunks * 8)..len {
        // SAFETY: i < len == a.len() == b.len().
        let ai = *a.get_unchecked(i);
        let bi = *b.get_unchecked(i);
        dot_sum += ai * bi;
        norm_a_sum += ai * ai;
        norm_b_sum += bi * bi;
    }

    dot_sum / (norm_a_sum.sqrt() * norm_b_sum.sqrt())
}
705
/// Unrolled NEON Manhattan (L1) distance.
///
/// Stages:
/// 1. 16 elements per iteration into four independent `|a - b|` accumulators
///    (via `vabdq_f32`).
/// 2. The four are folded as `(s0 + s1) + (s2 + s3)`.
/// 3. Remaining whole groups of 4 are added with a single accumulator.
/// 4. The last `len % 4` elements are handled in scalar code.
///
/// # Panics
///
/// Panics if `a.len() != b.len()`.
///
/// # Safety
///
/// Declared without `#[target_feature]`, so NEON must be enabled for the
/// aarch64 target being built.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn manhattan_distance_neon_unrolled_impl(a: &[f32], b: &[f32]) -> f32 {
    // Hard assert: `b` is read at offsets bounded only by `a.len()`.
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    let mut sum0 = vdupq_n_f32(0.0);
    let mut sum1 = vdupq_n_f32(0.0);
    let mut sum2 = vdupq_n_f32(0.0);
    let mut sum3 = vdupq_n_f32(0.0);

    let chunks = len / 16;
    let mut idx = 0usize;

    for _ in 0..chunks {
        // SAFETY: idx + 16 <= chunks * 16 <= len == a.len() == b.len().
        let va0 = vld1q_f32(a_ptr.add(idx));
        let vb0 = vld1q_f32(b_ptr.add(idx));
        sum0 = vaddq_f32(sum0, vabdq_f32(va0, vb0));

        let va1 = vld1q_f32(a_ptr.add(idx + 4));
        let vb1 = vld1q_f32(b_ptr.add(idx + 4));
        sum1 = vaddq_f32(sum1, vabdq_f32(va1, vb1));

        let va2 = vld1q_f32(a_ptr.add(idx + 8));
        let vb2 = vld1q_f32(b_ptr.add(idx + 8));
        sum2 = vaddq_f32(sum2, vabdq_f32(va2, vb2));

        let va3 = vld1q_f32(a_ptr.add(idx + 12));
        let vb3 = vld1q_f32(b_ptr.add(idx + 12));
        sum3 = vaddq_f32(sum3, vabdq_f32(va3, vb3));

        idx += 16;
    }

    let sum01 = vaddq_f32(sum0, sum1);
    let sum23 = vaddq_f32(sum2, sum3);
    let sum = vaddq_f32(sum01, sum23);

    let remaining_start = chunks * 16;
    let remaining_chunks = (len - remaining_start) / 4;
    let mut final_sum = sum;

    idx = remaining_start;
    for _ in 0..remaining_chunks {
        // SAFETY: idx + 4 <= remaining_start + remaining_chunks * 4 <= len.
        let va = vld1q_f32(a_ptr.add(idx));
        let vb = vld1q_f32(b_ptr.add(idx));
        final_sum = vaddq_f32(final_sum, vabdq_f32(va, vb));
        idx += 4;
    }

    let mut total = vaddvq_f32(final_sum);

    let scalar_start = remaining_start + remaining_chunks * 4;
    for i in scalar_start..len {
        // SAFETY: i < len == a.len() == b.len().
        total += (*a.get_unchecked(i) - *b.get_unchecked(i)).abs();
    }

    total
}
775
776#[inline(always)]
783pub fn dot_product_simd(a: &[f32], b: &[f32]) -> f32 {
784 #[cfg(target_arch = "x86_64")]
785 {
786 if is_x86_feature_detected!("avx512f") {
787 unsafe { dot_product_avx512_impl(a, b) }
788 } else if is_x86_feature_detected!("avx2") {
789 unsafe { dot_product_avx2_impl(a, b) }
790 } else {
791 dot_product_scalar(a, b)
792 }
793 }
794
795 #[cfg(target_arch = "aarch64")]
796 {
797 if a.len() >= 64 {
798 unsafe { dot_product_neon_unrolled_impl(a, b) }
799 } else {
800 unsafe { dot_product_neon_impl(a, b) }
801 }
802 }
803
804 #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
805 {
806 dot_product_scalar(a, b)
807 }
808}
809
810#[inline(always)]
812pub fn dot_product_avx2(a: &[f32], b: &[f32]) -> f32 {
813 dot_product_simd(a, b)
814}
815
/// AVX2 dot product using separate multiply and add (no FMA).
///
/// Products are accumulated 8 lanes at a time, the lanes are summed left to
/// right, and the final `len % 8` elements are added in scalar code.
///
/// # Panics
///
/// Panics if `a.len() != b.len()`.
///
/// # Safety
///
/// The running CPU must support AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn dot_product_avx2_impl(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let xs = a.chunks_exact(8);
    let ys = b.chunks_exact(8);
    let (x_tail, y_tail) = (xs.remainder(), ys.remainder());

    let mut acc = _mm256_setzero_ps();
    for (x, y) in xs.zip(ys) {
        let product = _mm256_mul_ps(_mm256_loadu_ps(x.as_ptr()), _mm256_loadu_ps(y.as_ptr()));
        acc = _mm256_add_ps(acc, product);
    }

    let mut lanes = [0.0f32; 8];
    _mm256_storeu_ps(lanes.as_mut_ptr(), acc);
    let mut total = lanes.iter().sum::<f32>();

    for (&x, &y) in x_tail.iter().zip(y_tail) {
        total += x * y;
    }
    total
}
843
844#[inline(always)]
847pub fn cosine_similarity_simd(a: &[f32], b: &[f32]) -> f32 {
848 #[cfg(target_arch = "x86_64")]
849 {
850 if is_x86_feature_detected!("avx512f") {
851 unsafe { cosine_similarity_avx512_impl(a, b) }
852 } else if is_x86_feature_detected!("avx2") {
853 unsafe { cosine_similarity_avx2_impl(a, b) }
854 } else {
855 cosine_similarity_scalar(a, b)
856 }
857 }
858
859 #[cfg(target_arch = "aarch64")]
860 {
861 if a.len() >= 64 {
862 unsafe { cosine_similarity_neon_unrolled_impl(a, b) }
863 } else {
864 unsafe { cosine_similarity_neon_impl(a, b) }
865 }
866 }
867
868 #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
869 {
870 cosine_similarity_scalar(a, b)
871 }
872}
873
874#[inline(always)]
876pub fn cosine_similarity_avx2(a: &[f32], b: &[f32]) -> f32 {
877 cosine_similarity_simd(a, b)
878}
879
880#[inline(always)]
883pub fn manhattan_distance_simd(a: &[f32], b: &[f32]) -> f32 {
884 #[cfg(target_arch = "x86_64")]
885 {
886 if is_x86_feature_detected!("avx512f") {
887 unsafe { manhattan_distance_avx512_impl(a, b) }
888 } else if is_x86_feature_detected!("avx2") {
889 unsafe { manhattan_distance_avx2_impl(a, b) }
890 } else {
891 manhattan_distance_scalar(a, b)
892 }
893 }
894
895 #[cfg(target_arch = "aarch64")]
896 {
897 if a.len() >= 64 {
898 unsafe { manhattan_distance_neon_unrolled_impl(a, b) }
899 } else {
900 unsafe { manhattan_distance_neon_impl(a, b) }
901 }
902 }
903
904 #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
905 {
906 manhattan_distance_scalar(a, b)
907 }
908}
909
/// AVX2 cosine similarity using separate multiply and add (no FMA).
///
/// One pass accumulates `a·b`, `a·a` and `b·b` in three 8-lane registers.
/// Each register's lanes are summed left to right, the last `len % 8`
/// elements are added in scalar code, and the result is `dot / (|a| * |b|)`.
/// That quotient is not finite when a norm is zero.
///
/// # Panics
///
/// Panics if `a.len() != b.len()`.
///
/// # Safety
///
/// The running CPU must support AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn cosine_similarity_avx2_impl(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let xs = a.chunks_exact(8);
    let ys = b.chunks_exact(8);
    let (x_tail, y_tail) = (xs.remainder(), ys.remainder());

    let mut dot = _mm256_setzero_ps();
    let mut sq_a = _mm256_setzero_ps();
    let mut sq_b = _mm256_setzero_ps();
    for (x, y) in xs.zip(ys) {
        let vx = _mm256_loadu_ps(x.as_ptr());
        let vy = _mm256_loadu_ps(y.as_ptr());
        dot = _mm256_add_ps(dot, _mm256_mul_ps(vx, vy));
        sq_a = _mm256_add_ps(sq_a, _mm256_mul_ps(vx, vx));
        sq_b = _mm256_add_ps(sq_b, _mm256_mul_ps(vy, vy));
    }

    // One scratch buffer, reused for each horizontal (lane-order) sum.
    let mut scratch = [0.0f32; 8];
    _mm256_storeu_ps(scratch.as_mut_ptr(), dot);
    let mut dot_total = scratch.iter().sum::<f32>();
    _mm256_storeu_ps(scratch.as_mut_ptr(), sq_a);
    let mut a_total = scratch.iter().sum::<f32>();
    _mm256_storeu_ps(scratch.as_mut_ptr(), sq_b);
    let mut b_total = scratch.iter().sum::<f32>();

    for (&x, &y) in x_tail.iter().zip(y_tail) {
        dot_total += x * y;
        a_total += x * x;
        b_total += y * y;
    }

    dot_total / (a_total.sqrt() * b_total.sqrt())
}
951
/// AVX2 Manhattan (L1) distance with two independent accumulators.
///
/// `|x|` is computed by clearing the IEEE-754 sign bit with a bitwise AND.
///
/// Stages:
/// 1. 16-element blocks: each 8-lane half feeds its own accumulator.
/// 2. The two accumulators are added.
/// 3. At most one leftover 8-element group is folded in.
/// 4. The lanes are summed left to right, and the final `len % 8` elements
///    are added in scalar code.
///
/// # Panics
///
/// Panics if `a.len() != b.len()`.
///
/// # Safety
///
/// The running CPU must support AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn manhattan_distance_avx2_impl(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    // Every bit set except bit 31 (the sign bit); built from an integer so the
    // mask never travels as a float NaN value.
    let abs_mask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7FFF_FFFF));

    let wide_a = a.chunks_exact(16);
    let wide_b = b.chunks_exact(16);
    let (rest_a, rest_b) = (wide_a.remainder(), wide_b.remainder());

    // Stage 1: accumulator `h` always receives elements [8h, 8h + 8) of a block.
    let mut halves = [_mm256_setzero_ps(); 2];
    for (blk_a, blk_b) in wide_a.zip(wide_b) {
        for (h, slot) in halves.iter_mut().enumerate() {
            let off = h * 8;
            let delta = _mm256_sub_ps(
                _mm256_loadu_ps(blk_a.as_ptr().add(off)),
                _mm256_loadu_ps(blk_b.as_ptr().add(off)),
            );
            *slot = _mm256_add_ps(*slot, _mm256_and_ps(delta, abs_mask));
        }
    }

    // Stage 2.
    let mut acc = _mm256_add_ps(halves[0], halves[1]);

    // Stage 3.
    let mid_a = rest_a.chunks_exact(8);
    let mid_b = rest_b.chunks_exact(8);
    let (tail_a, tail_b) = (mid_a.remainder(), mid_b.remainder());
    for (x, y) in mid_a.zip(mid_b) {
        let delta = _mm256_sub_ps(_mm256_loadu_ps(x.as_ptr()), _mm256_loadu_ps(y.as_ptr()));
        acc = _mm256_add_ps(acc, _mm256_and_ps(delta, abs_mask));
    }

    // Stage 4.
    let mut lanes = [0.0f32; 8];
    _mm256_storeu_ps(lanes.as_mut_ptr(), acc);
    let mut total = lanes.iter().sum::<f32>();

    for (&x, &y) in tail_a.iter().zip(tail_b) {
        total += (x - y).abs();
    }

    total
}
1008
/// Portable Euclidean (L2) distance: the square root of the summed squared
/// element differences.
///
/// Iterates with `zip`, so if the lengths differ only the common prefix
/// contributes.
#[allow(dead_code)] // Not called by the aarch64 dispatch path, which always picks NEON.
fn euclidean_distance_scalar(a: &[f32], b: &[f32]) -> f32 {
    let sum_of_squares: f32 = a
        .iter()
        .zip(b)
        .map(|(&x, &y)| (x - y) * (x - y))
        .sum();
    sum_of_squares.sqrt()
}
1023
/// Portable dot product of `a` and `b`.
///
/// Iterates with `zip`, so if the lengths differ only the common prefix
/// contributes.
#[allow(dead_code)] // Not called by the aarch64 dispatch path, which always picks NEON.
fn dot_product_scalar(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(&x, &y)| x * y).sum::<f32>()
}
1028
/// Portable cosine similarity: `dot(a, b) / (|a| * |b|)`.
///
/// The dot product runs over the `zip`ped common prefix, while each norm runs
/// over its whole slice. No zero-norm guard is applied, so a zero-magnitude
/// input yields a non-finite result.
#[allow(dead_code)] // Not called by the aarch64 dispatch path, which always picks NEON.
fn cosine_similarity_scalar(a: &[f32], b: &[f32]) -> f32 {
    let magnitude = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let dot = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
    dot / (magnitude(a) * magnitude(b))
}
1036
/// Portable Manhattan (L1) distance: the sum of absolute element differences.
///
/// Iterates with `zip`, so if the lengths differ only the common prefix
/// contributes.
#[allow(dead_code)] // Not called by the aarch64 dispatch path, which always picks NEON.
fn manhattan_distance_scalar(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b)
        .map(|(&x, &y)| (x - y).abs())
        .sum::<f32>()
}
1041
1042#[inline(always)]
1049pub fn dot_product_i8(a: &[i8], b: &[i8]) -> i32 {
1050 #[cfg(target_arch = "x86_64")]
1051 {
1052 if is_x86_feature_detected!("avx2") {
1053 unsafe { dot_product_i8_avx2_impl(a, b) }
1054 } else {
1055 dot_product_i8_scalar(a, b)
1056 }
1057 }
1058
1059 #[cfg(target_arch = "aarch64")]
1060 {
1061 unsafe { dot_product_i8_neon_impl(a, b) }
1062 }
1063
1064 #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
1065 {
1066 dot_product_i8_scalar(a, b)
1067 }
1068}
1069
1070#[inline(always)]
1073pub fn euclidean_distance_squared_i8(a: &[i8], b: &[i8]) -> i32 {
1074 #[cfg(target_arch = "x86_64")]
1075 {
1076 if is_x86_feature_detected!("avx2") {
1077 unsafe { euclidean_distance_squared_i8_avx2_impl(a, b) }
1078 } else {
1079 euclidean_distance_squared_i8_scalar(a, b)
1080 }
1081 }
1082
1083 #[cfg(target_arch = "aarch64")]
1084 {
1085 unsafe { euclidean_distance_squared_i8_neon_impl(a, b) }
1086 }
1087
1088 #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
1089 {
1090 euclidean_distance_squared_i8_scalar(a, b)
1091 }
1092}
1093
/// NEON `i8` dot product, 8 elements per step.
///
/// Each 8-byte load is sign-extended to `i16` (`vmovl_s8`). The low and high
/// halves are multiplied into `i32` with widening multiplies (`vmull_s16`)
/// and accumulated, then reduced with `vaddvq_s32`. The last `len % 8`
/// elements are handled in scalar code.
///
/// # Panics
///
/// Panics if `a.len() != b.len()`.
///
/// # Safety
///
/// Declared without `#[target_feature]`, so NEON must be enabled for the
/// aarch64 target being built.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn dot_product_i8_neon_impl(a: &[i8], b: &[i8]) -> i32 {
    // Hard assert: `b` is read through raw pointers and `get_unchecked` at
    // offsets bounded only by `a.len()`.
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    let mut sum = vdupq_n_s32(0);

    let chunks = len / 8;
    let mut idx = 0usize;

    for _ in 0..chunks {
        // SAFETY: idx + 8 <= chunks * 8 <= len == a.len() == b.len().
        let va = vld1_s8(a_ptr.add(idx));
        let vb = vld1_s8(b_ptr.add(idx));

        // Sign-extend 8 x i8 to 8 x i16 before the widening multiply.
        let va_i16 = vmovl_s8(va);
        let vb_i16 = vmovl_s8(vb);

        // i16 x i16 -> i32 products for the low and high four lanes.
        let prod_lo = vmull_s16(vget_low_s16(va_i16), vget_low_s16(vb_i16));
        let prod_hi = vmull_s16(vget_high_s16(va_i16), vget_high_s16(vb_i16));

        sum = vaddq_s32(sum, prod_lo);
        sum = vaddq_s32(sum, prod_hi);

        idx += 8;
    }

    let mut total = vaddvq_s32(sum);

    for i in (chunks * 8)..len {
        // SAFETY: i < len == a.len() == b.len().
        total += (*a.get_unchecked(i) as i32) * (*b.get_unchecked(i) as i32);
    }

    total
}
1143
/// NEON squared Euclidean distance for `i8` inputs, 8 elements per step.
///
/// Inputs are sign-extended to `i16` and subtracted there, so the difference
/// (range [-255, 255]) cannot wrap. Differences are squared with widening
/// `i16 -> i32` multiplies, accumulated, and reduced with `vaddvq_s32`. The
/// last `len % 8` elements are handled in scalar code. No square root is
/// taken.
///
/// # Panics
///
/// Panics if `a.len() != b.len()`.
///
/// # Safety
///
/// Declared without `#[target_feature]`, so NEON must be enabled for the
/// aarch64 target being built.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn euclidean_distance_squared_i8_neon_impl(a: &[i8], b: &[i8]) -> i32 {
    // Hard assert: `b` is read through raw pointers and `get_unchecked` at
    // offsets bounded only by `a.len()`.
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    let mut sum = vdupq_n_s32(0);

    let chunks = len / 8;
    let mut idx = 0usize;

    for _ in 0..chunks {
        // SAFETY: idx + 8 <= chunks * 8 <= len == a.len() == b.len().
        let va = vld1_s8(a_ptr.add(idx));
        let vb = vld1_s8(b_ptr.add(idx));

        let va_i16 = vmovl_s8(va);
        let vb_i16 = vmovl_s8(vb);

        // Subtract in i16 so e.g. 127 - (-128) = 255 does not wrap.
        let diff = vsubq_s16(va_i16, vb_i16);

        let prod_lo = vmull_s16(vget_low_s16(diff), vget_low_s16(diff));
        let prod_hi = vmull_s16(vget_high_s16(diff), vget_high_s16(diff));

        sum = vaddq_s32(sum, prod_lo);
        sum = vaddq_s32(sum, prod_hi);

        idx += 8;
    }

    let mut total = vaddvq_s32(sum);

    for i in (chunks * 8)..len {
        // SAFETY: i < len == a.len() == b.len().
        let diff = (*a.get_unchecked(i) as i32) - (*b.get_unchecked(i) as i32);
        total += diff * diff;
    }

    total
}
1194
/// AVX2 `i8` dot product, 32 bytes per step.
///
/// Each 32-byte load is split into two 16-byte halves. Each half is
/// sign-extended to sixteen `i16` lanes, and `_mm256_madd_epi16` multiplies
/// and pairwise-adds into eight `i32` lanes. The `i32` lanes are summed and
/// the final `len % 32` elements are handled in scalar code.
///
/// # Panics
///
/// Panics if `a.len() != b.len()`.
///
/// # Safety
///
/// The running CPU must support AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn dot_product_i8_avx2_impl(a: &[i8], b: &[i8]) -> i32 {
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let xs = a.chunks_exact(32);
    let ys = b.chunks_exact(32);
    let (x_tail, y_tail) = (xs.remainder(), ys.remainder());

    let mut acc = _mm256_setzero_si256();
    for (x, y) in xs.zip(ys) {
        let vx = _mm256_loadu_si256(x.as_ptr().cast::<__m256i>());
        let vy = _mm256_loadu_si256(y.as_ptr().cast::<__m256i>());

        let low = _mm256_madd_epi16(
            _mm256_cvtepi8_epi16(_mm256_castsi256_si128(vx)),
            _mm256_cvtepi8_epi16(_mm256_castsi256_si128(vy)),
        );
        let high = _mm256_madd_epi16(
            _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx, 1)),
            _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vy, 1)),
        );

        acc = _mm256_add_epi32(acc, low);
        acc = _mm256_add_epi32(acc, high);
    }

    let mut lanes = [0i32; 8];
    _mm256_storeu_si256(lanes.as_mut_ptr().cast::<__m256i>(), acc);
    let mut total: i32 = lanes.iter().sum();

    for (&x, &y) in x_tail.iter().zip(y_tail) {
        total += (x as i32) * (y as i32);
    }

    total
}
1235
/// AVX2 squared Euclidean distance for `i8` inputs, 32 bytes per step.
///
/// Both halves of each load are sign-extended to `i16` and subtracted there,
/// so a difference such as `127 - (-128)` does not wrap.
/// `_mm256_madd_epi16(d, d)` squares and pairwise-adds into `i32` lanes. The
/// lanes are summed and the final `len % 32` elements are handled in scalar
/// code. No square root is taken.
///
/// # Panics
///
/// Panics if `a.len() != b.len()`.
///
/// # Safety
///
/// The running CPU must support AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn euclidean_distance_squared_i8_avx2_impl(a: &[i8], b: &[i8]) -> i32 {
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let xs = a.chunks_exact(32);
    let ys = b.chunks_exact(32);
    let (x_tail, y_tail) = (xs.remainder(), ys.remainder());

    let mut acc = _mm256_setzero_si256();
    for (x, y) in xs.zip(ys) {
        let vx = _mm256_loadu_si256(x.as_ptr().cast::<__m256i>());
        let vy = _mm256_loadu_si256(y.as_ptr().cast::<__m256i>());

        let delta_low = _mm256_sub_epi16(
            _mm256_cvtepi8_epi16(_mm256_castsi256_si128(vx)),
            _mm256_cvtepi8_epi16(_mm256_castsi256_si128(vy)),
        );
        let delta_high = _mm256_sub_epi16(
            _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx, 1)),
            _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vy, 1)),
        );

        acc = _mm256_add_epi32(acc, _mm256_madd_epi16(delta_low, delta_low));
        acc = _mm256_add_epi32(acc, _mm256_madd_epi16(delta_high, delta_high));
    }

    let mut lanes = [0i32; 8];
    _mm256_storeu_si256(lanes.as_mut_ptr().cast::<__m256i>(), acc);
    let mut total: i32 = lanes.iter().sum();

    for (&x, &y) in x_tail.iter().zip(y_tail) {
        let delta = (x as i32) - (y as i32);
        total += delta * delta;
    }

    total
}
1277
/// Portable `i8` dot product accumulated in `i32`.
///
/// Iterates over the `zip`ped common prefix of `a` and `b`.
#[allow(dead_code)] // Not called by the aarch64 dispatch path, which always picks NEON.
fn dot_product_i8_scalar(a: &[i8], b: &[i8]) -> i32 {
    let mut acc: i32 = 0;
    for (&x, &y) in a.iter().zip(b) {
        acc += i32::from(x) * i32::from(y);
    }
    acc
}
1286
/// Portable squared Euclidean distance for `i8` inputs, accumulated in `i32`.
///
/// Each difference is computed after widening to `i32`, so it cannot wrap.
/// Iterates over the `zip`ped common prefix of `a` and `b`.
#[allow(dead_code)] // Not called by the aarch64 dispatch path, which always picks NEON.
fn euclidean_distance_squared_i8_scalar(a: &[i8], b: &[i8]) -> i32 {
    let mut acc: i32 = 0;
    for (&x, &y) in a.iter().zip(b) {
        let delta = i32::from(x) - i32::from(y);
        acc += delta * delta;
    }
    acc
}
1298
1299#[inline]
1307pub fn batch_dot_product(query: &[f32], vectors: &[&[f32]], results: &mut [f32]) {
1308 assert_eq!(
1309 vectors.len(),
1310 results.len(),
1311 "Output size must match vector count"
1312 );
1313
1314 const TILE_SIZE: usize = 16;
1316
1317 for (chunk_idx, chunk) in vectors.chunks(TILE_SIZE).enumerate() {
1318 let base_idx = chunk_idx * TILE_SIZE;
1319 for (i, vec) in chunk.iter().enumerate() {
1320 results[base_idx + i] = dot_product_simd(query, vec);
1321 }
1322 }
1323}
1324
1325#[inline]
1329pub fn batch_euclidean(query: &[f32], vectors: &[&[f32]], results: &mut [f32]) {
1330 assert_eq!(
1331 vectors.len(),
1332 results.len(),
1333 "Output size must match vector count"
1334 );
1335
1336 const TILE_SIZE: usize = 16;
1337
1338 for (chunk_idx, chunk) in vectors.chunks(TILE_SIZE).enumerate() {
1339 let base_idx = chunk_idx * TILE_SIZE;
1340 for (i, vec) in chunk.iter().enumerate() {
1341 results[base_idx + i] = euclidean_distance_simd(query, vec);
1342 }
1343 }
1344}
1345
1346#[inline]
1348pub fn batch_cosine_similarity(query: &[f32], vectors: &[&[f32]], results: &mut [f32]) {
1349 assert_eq!(
1350 vectors.len(),
1351 results.len(),
1352 "Output size must match vector count"
1353 );
1354
1355 const TILE_SIZE: usize = 16;
1356
1357 for (chunk_idx, chunk) in vectors.chunks(TILE_SIZE).enumerate() {
1358 let base_idx = chunk_idx * TILE_SIZE;
1359 for (i, vec) in chunk.iter().enumerate() {
1360 results[base_idx + i] = cosine_similarity_simd(query, vec);
1361 }
1362 }
1363}
1364
1365#[inline]
1367pub fn batch_dot_product_owned(query: &[f32], vectors: &[Vec<f32>]) -> Vec<f32> {
1368 let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();
1369 let mut results = vec![0.0; vectors.len()];
1370 batch_dot_product(query, &refs, &mut results);
1371 results
1372}
1373
1374#[inline]
1376pub fn batch_euclidean_owned(query: &[f32], vectors: &[Vec<f32>]) -> Vec<f32> {
1377 let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();
1378 let mut results = vec![0.0; vectors.len()];
1379 batch_euclidean(query, &refs, &mut results);
1380 results
1381}
1382
#[cfg(test)]
mod tests {
    use super::*;

    /// Lengths that straddle the 8-element AVX2 chunk size and the 64-element
    /// threshold at which the aarch64 dispatcher switches to the unrolled NEON
    /// kernel, so both the vector bodies and their scalar tails are exercised.
    const BOUNDARY_LENGTHS: [usize; 13] = [1, 7, 8, 9, 15, 16, 17, 63, 64, 65, 127, 128, 129];

    #[test]
    fn test_euclidean_distance_simd() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let b = vec![2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];

        let result = euclidean_distance_simd(&a, &b);
        let expected = euclidean_distance_scalar(&a, &b);

        assert!(
            (result - expected).abs() < 0.001,
            "SIMD result {} differs from scalar result {}",
            result,
            expected
        );
    }

    #[test]
    fn test_euclidean_distance_large() {
        // Constant offset of 0.5 per coordinate over 128 elements.
        let a: Vec<f32> = (0..128).map(|i| i as f32 * 0.1).collect();
        let b: Vec<f32> = (0..128).map(|i| (i as f32 * 0.1) + 0.5).collect();

        let result = euclidean_distance_simd(&a, &b);
        let expected = euclidean_distance_scalar(&a, &b);

        assert!(
            (result - expected).abs() < 0.01,
            "Large vector: SIMD {} vs scalar {}",
            result,
            expected
        );
    }

    #[test]
    fn test_dot_product_simd() {
        let a = vec![1.0; 16];
        let b = vec![2.0; 16];

        let result = dot_product_simd(&a, &b);
        assert!((result - 32.0).abs() < 0.001);
    }

    #[test]
    fn test_dot_product_large() {
        let a: Vec<f32> = (0..256).map(|i| (i % 10) as f32).collect();
        let b: Vec<f32> = (0..256).map(|i| ((i + 5) % 10) as f32).collect();

        let result = dot_product_simd(&a, &b);
        let expected = dot_product_scalar(&a, &b);

        assert!(
            (result - expected).abs() < 0.1,
            "Large dot product: SIMD {} vs scalar {}",
            result,
            expected
        );
    }

    #[test]
    fn test_dot_product_boundary_lengths() {
        // Small integer inputs keep every partial sum exactly representable,
        // so any mismatch points at a dropped or double-counted element.
        for len in BOUNDARY_LENGTHS {
            let a: Vec<f32> = (0..len).map(|i| (i % 10) as f32).collect();
            let b: Vec<f32> = (0..len).map(|i| ((i + 5) % 10) as f32).collect();

            let result = dot_product_simd(&a, &b);
            let expected = dot_product_scalar(&a, &b);

            assert!(
                (result - expected).abs() < 1e-3 * expected.abs().max(1.0),
                "len {}: SIMD dot {} vs scalar {}",
                len,
                result,
                expected
            );
        }
    }

    #[test]
    fn test_cosine_similarity_simd() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];

        let result = cosine_similarity_simd(&a, &b);
        assert!((result - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = vec![1.0, 0.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0, 0.0];

        let result = cosine_similarity_simd(&a, &b);
        assert!(
            result.abs() < 0.001,
            "Orthogonal vectors should have ~0 similarity, got {}",
            result
        );
    }

    #[test]
    fn test_manhattan_distance_simd() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];

        let result = manhattan_distance_simd(&a, &b);
        let expected = manhattan_distance_scalar(&a, &b);

        assert!(
            (result - expected).abs() < 0.001,
            "Manhattan: SIMD {} vs scalar {}",
            result,
            expected
        );
        // Every coordinate differs by 4: 4 * 4 = 16.
        assert!((result - 16.0).abs() < 0.001);
    }

    #[test]
    fn test_non_aligned_lengths() {
        // 7 elements: shorter than one 8-wide vector, so only the tail runs.
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];
        let b = vec![2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];

        let result = euclidean_distance_simd(&a, &b);
        let expected = euclidean_distance_scalar(&a, &b);

        assert!(
            (result - expected).abs() < 0.001,
            "Non-aligned: SIMD {} vs scalar {}",
            result,
            expected
        );
    }

    #[test]
    fn test_legacy_avx2_aliases() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];

        // Every coordinate differs by 4, so the distance is sqrt(4 * 16) = 8.
        let dist = euclidean_distance_avx2(&a, &b);
        assert!((dist - 8.0).abs() < 1e-3, "Legacy euclidean alias: got {}", dist);

        // 1*5 + 2*6 + 3*7 + 4*8 = 70.
        let dot = dot_product_avx2(&a, &b);
        assert!((dot - 70.0).abs() < 1e-3, "Legacy dot alias: got {}", dot);

        // |a|^2 = 30 and |b|^2 = 174.
        let expected_cos = 70.0_f32 / (30.0_f32.sqrt() * 174.0_f32.sqrt());
        let cos = cosine_similarity_avx2(&a, &b);
        assert!(
            (cos - expected_cos).abs() < 1e-3,
            "Legacy cosine alias: got {}, expected {}",
            cos,
            expected_cos
        );
    }

    #[test]
    fn test_dot_product_i8() {
        let a: Vec<i8> = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
        let b: Vec<i8> = vec![2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17];

        let result = dot_product_i8(&a, &b);
        let expected = dot_product_i8_scalar(&a, &b);

        assert_eq!(
            result, expected,
            "INT8 dot product: SIMD {} vs scalar {}",
            result, expected
        );
    }

    #[test]
    fn test_dot_product_i8_large() {
        // Values span -64..=63, mixing signs across the whole vector.
        let a: Vec<i8> = (0..128)
            .map(|i| ((i % 256) as i8).wrapping_sub(64))
            .collect();
        let b: Vec<i8> = (0..128)
            .map(|i| (((i + 10) % 256) as i8).wrapping_sub(64))
            .collect();

        let result = dot_product_i8(&a, &b);
        let expected = dot_product_i8_scalar(&a, &b);

        assert_eq!(
            result, expected,
            "Large INT8 dot product: SIMD {} vs scalar {}",
            result, expected
        );
    }

    #[test]
    fn test_euclidean_distance_squared_i8() {
        let a: Vec<i8> = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
        let b: Vec<i8> = vec![2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17];

        let result = euclidean_distance_squared_i8(&a, &b);
        let expected = euclidean_distance_squared_i8_scalar(&a, &b);

        assert_eq!(
            result, expected,
            "INT8 euclidean^2: SIMD {} vs scalar {}",
            result, expected
        );
        // 16 coordinates, each differing by exactly 1.
        assert_eq!(result, 16, "Expected 16, got {}", result);
    }

    #[test]
    fn test_euclidean_distance_squared_i8_large() {
        let a: Vec<i8> = (0..128)
            .map(|i| ((i % 256) as i8).wrapping_sub(64))
            .collect();
        let b: Vec<i8> = (0..128)
            .map(|i| (((i + 5) % 256) as i8).wrapping_sub(64))
            .collect();

        let result = euclidean_distance_squared_i8(&a, &b);
        let expected = euclidean_distance_squared_i8_scalar(&a, &b);

        assert_eq!(
            result, expected,
            "Large INT8 euclidean^2: SIMD {} vs scalar {}",
            result, expected
        );
    }

    #[test]
    fn test_batch_dot_product() {
        let query = vec![1.0, 2.0, 3.0, 4.0];
        let v1 = vec![1.0, 0.0, 0.0, 0.0];
        let v2 = vec![0.0, 1.0, 0.0, 0.0];
        let v3 = vec![0.0, 0.0, 1.0, 0.0];
        let vectors: Vec<&[f32]> = vec![&v1, &v2, &v3];
        let mut results = vec![0.0; 3];

        batch_dot_product(&query, &vectors, &mut results);

        assert!((results[0] - 1.0).abs() < 0.001);
        assert!((results[1] - 2.0).abs() < 0.001);
        assert!((results[2] - 3.0).abs() < 0.001);
    }

    #[test]
    fn test_batch_euclidean() {
        let query = vec![0.0, 0.0, 0.0, 0.0];
        let v1 = vec![3.0, 4.0, 0.0, 0.0];
        let v2 = vec![0.0, 0.0, 5.0, 12.0];
        let vectors: Vec<&[f32]> = vec![&v1, &v2];
        let mut results = vec![0.0; 2];

        batch_euclidean(&query, &vectors, &mut results);

        assert!(
            (results[0] - 5.0).abs() < 0.001,
            "Expected 5.0, got {}",
            results[0]
        );
        assert!(
            (results[1] - 13.0).abs() < 0.001,
            "Expected 13.0, got {}",
            results[1]
        );
    }

    #[test]
    fn test_batch_cosine_similarity() {
        let query = vec![1.0, 0.0, 0.0, 0.0];
        let v1 = vec![1.0, 0.0, 0.0, 0.0]; // same direction
        let v2 = vec![0.0, 1.0, 0.0, 0.0]; // orthogonal
        let v3 = vec![-1.0, 0.0, 0.0, 0.0]; // opposite
        let vectors: Vec<&[f32]> = vec![&v1, &v2, &v3];
        let mut results = vec![0.0; 3];

        batch_cosine_similarity(&query, &vectors, &mut results);

        assert!(
            (results[0] - 1.0).abs() < 0.001,
            "Same direction should be 1.0"
        );
        assert!(results[1].abs() < 0.001, "Orthogonal should be 0.0");
        assert!((results[2] + 1.0).abs() < 0.001, "Opposite should be -1.0");
    }

    #[test]
    fn test_batch_many_vectors_keep_their_index() {
        // More vectors than a handful, so any internal chunking of the batch
        // must still write each result back to its own index.
        let query = vec![1.0, 2.0, 3.0, 4.0];
        let origin = vec![0.0; 4];
        let owned: Vec<Vec<f32>> = (0..40).map(|k| vec![k as f32, 0.0, 0.0, 0.0]).collect();
        let vectors: Vec<&[f32]> = owned.iter().map(|v| v.as_slice()).collect();

        let mut dots = vec![0.0; vectors.len()];
        batch_dot_product(&query, &vectors, &mut dots);

        let mut dists = vec![0.0; vectors.len()];
        batch_euclidean(&origin, &vectors, &mut dists);

        for (k, (dot, dist)) in dots.iter().zip(&dists).enumerate() {
            let want = k as f32;
            assert!((dot - want).abs() < 1e-3, "dot[{}] = {}, expected {}", k, dot, want);
            assert!((dist - want).abs() < 1e-3, "dist[{}] = {}, expected {}", k, dist, want);
        }
    }

    #[test]
    fn test_batch_empty() {
        let query = vec![1.0, 2.0, 3.0, 4.0];
        let vectors: Vec<&[f32]> = Vec::new();
        let mut results: Vec<f32> = Vec::new();

        batch_dot_product(&query, &vectors, &mut results);
        assert!(results.is_empty());

        let none: Vec<Vec<f32>> = Vec::new();
        assert!(batch_euclidean_owned(&query, &none).is_empty());
    }

    #[test]
    #[should_panic(expected = "Output size must match vector count")]
    fn test_batch_size_mismatch_panics() {
        let query = vec![1.0, 0.0];
        let v1 = vec![1.0, 0.0];
        let vectors: Vec<&[f32]> = vec![&v1];
        let mut results = vec![0.0; 2];

        batch_dot_product(&query, &vectors, &mut results);
    }

    #[test]
    fn test_batch_owned_convenience() {
        let query = vec![1.0, 2.0, 3.0, 4.0];
        let vectors = vec![vec![1.0, 0.0, 0.0, 0.0], vec![0.0, 1.0, 0.0, 0.0]];

        let results = batch_dot_product_owned(&query, &vectors);
        assert_eq!(results.len(), 2);
        assert!((results[0] - 1.0).abs() < 0.001);
        assert!((results[1] - 2.0).abs() < 0.001);

        let origin = vec![0.0; 4];
        let targets = vec![vec![3.0, 4.0, 0.0, 0.0], vec![0.0, 0.0, 5.0, 12.0]];
        let dists = batch_euclidean_owned(&origin, &targets);
        assert_eq!(dists.len(), 2);
        assert!((dists[0] - 5.0).abs() < 0.001, "Expected 5.0, got {}", dists[0]);
        assert!((dists[1] - 13.0).abs() < 0.001, "Expected 13.0, got {}", dists[1]);
    }

    #[test]
    fn test_unrolled_vs_non_unrolled_consistency() {
        // Compare against the scalar reference on both sides of every
        // dispatch / chunking boundary, rather than at a single length that
        // only ever hits one kernel.
        for len in BOUNDARY_LENGTHS {
            let a: Vec<f32> = (0..len).map(|i| i as f32 * 0.1).collect();
            let b: Vec<f32> = (0..len).map(|i| (i as f32 * 0.1) + 0.5).collect();

            let result = euclidean_distance_simd(&a, &b);
            let expected = euclidean_distance_scalar(&a, &b);

            assert!(
                (result - expected).abs() < 1e-3 * expected.max(1.0),
                "len {}: SIMD {} vs scalar {}",
                len,
                result,
                expected
            );
        }
    }
}