1#[cfg(target_arch = "x86_64")]
25use std::arch::x86_64::*;
26
27#[cfg(target_arch = "aarch64")]
28use std::arch::aarch64::*;
29
30#[allow(dead_code)]
32const PREFETCH_DISTANCE: usize = 64;
33
#[inline(always)]
/// Euclidean (L2) distance between `a` and `b`, dispatched at runtime to the
/// fastest SIMD kernel available on the current CPU.
///
/// x86_64: prefers AVX-512F, then AVX2+FMA, then AVX2, then the scalar
/// fallback. aarch64: NEON is baseline; inputs of 64+ elements use the
/// unrolled kernel. All other targets use the scalar fallback.
///
/// # Panics
/// The x86_64 SIMD kernels panic if `a.len() != b.len()`. NOTE(review): the
/// NEON kernels only debug-assert the lengths — confirm callers always pass
/// equal-length slices.
pub fn euclidean_distance_simd(a: &[f32], b: &[f32]) -> f32 {
    // Exactly one cfg block survives compilation for a given target, so the
    // surviving block is this function's tail expression.
    #[cfg(target_arch = "x86_64")]
    {
        // SAFETY (each arm): the matching feature was just detected, so the
        // #[target_feature] contract of the callee is satisfied.
        if is_x86_feature_detected!("avx512f") {
            unsafe { euclidean_distance_avx512_impl(a, b) }
        } else if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
            unsafe { euclidean_distance_avx2_fma_impl(a, b) }
        } else if is_x86_feature_detected!("avx2") {
            unsafe { euclidean_distance_avx2_impl(a, b) }
        } else {
            euclidean_distance_scalar(a, b)
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        // 64 elements is the heuristic threshold for switching to the
        // 4x-unrolled kernel.
        if a.len() >= 64 {
            unsafe { euclidean_distance_neon_unrolled_impl(a, b) }
        } else {
            unsafe { euclidean_distance_neon_impl(a, b) }
        }
    }

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        euclidean_distance_scalar(a, b)
    }
}
71
#[inline(always)]
/// Forwarding wrapper, presumably kept for backward compatibility. Despite the
/// name it does not force AVX2: it delegates to [`euclidean_distance_simd`],
/// which selects the best available instruction set at runtime.
pub fn euclidean_distance_avx2(a: &[f32], b: &[f32]) -> f32 {
    euclidean_distance_simd(a, b)
}
77
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
/// AVX2 Euclidean distance: 8 f32 lanes per iteration, then a scalar tail.
///
/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g. via
/// `is_x86_feature_detected!("avx2")`).
///
/// # Panics
/// Panics if `a.len() != b.len()`.
unsafe fn euclidean_distance_avx2_impl(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    // Vector accumulator of squared differences.
    let mut sum = _mm256_setzero_ps();

    let chunks = len / 8;
    for i in 0..chunks {
        let idx = i * 8;

        // Unaligned loads: slices carry no alignment guarantee.
        let va = _mm256_loadu_ps(a.as_ptr().add(idx));
        let vb = _mm256_loadu_ps(b.as_ptr().add(idx));

        let diff = _mm256_sub_ps(va, vb);

        let sq = _mm256_mul_ps(diff, diff);

        sum = _mm256_add_ps(sum, sq);
    }

    // Horizontal reduction: reinterpret the register as 8 f32s and sum them.
    let sum_arr: [f32; 8] = std::mem::transmute(sum);
    let mut total = sum_arr.iter().sum::<f32>();

    // Scalar tail for the remaining (len % 8) elements.
    for i in (chunks * 8)..len {
        let diff = a[i] - b[i];
        total += diff * diff;
    }

    total.sqrt()
}
118
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
/// AVX2+FMA Euclidean distance: 32 f32s per iteration across four independent
/// accumulators, then leftover 8-wide chunks, then a scalar tail.
///
/// # Safety
/// Caller must ensure the CPU supports both AVX2 and FMA (e.g. via
/// `is_x86_feature_detected!`).
///
/// # Panics
/// Panics if `a.len() != b.len()`.
unsafe fn euclidean_distance_avx2_fma_impl(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    // Four independent accumulators so consecutive FMAs do not serialize on a
    // single register dependency chain.
    let mut sum0 = _mm256_setzero_ps();
    let mut sum1 = _mm256_setzero_ps();
    let mut sum2 = _mm256_setzero_ps();
    let mut sum3 = _mm256_setzero_ps();

    let chunks = len / 32;
    for i in 0..chunks {
        let idx = i * 32;

        let va0 = _mm256_loadu_ps(a.as_ptr().add(idx));
        let vb0 = _mm256_loadu_ps(b.as_ptr().add(idx));
        let diff0 = _mm256_sub_ps(va0, vb0);
        // fmadd: sum += diff * diff with a single rounding.
        sum0 = _mm256_fmadd_ps(diff0, diff0, sum0);

        let va1 = _mm256_loadu_ps(a.as_ptr().add(idx + 8));
        let vb1 = _mm256_loadu_ps(b.as_ptr().add(idx + 8));
        let diff1 = _mm256_sub_ps(va1, vb1);
        sum1 = _mm256_fmadd_ps(diff1, diff1, sum1);

        let va2 = _mm256_loadu_ps(a.as_ptr().add(idx + 16));
        let vb2 = _mm256_loadu_ps(b.as_ptr().add(idx + 16));
        let diff2 = _mm256_sub_ps(va2, vb2);
        sum2 = _mm256_fmadd_ps(diff2, diff2, sum2);

        let va3 = _mm256_loadu_ps(a.as_ptr().add(idx + 24));
        let vb3 = _mm256_loadu_ps(b.as_ptr().add(idx + 24));
        let diff3 = _mm256_sub_ps(va3, vb3);
        sum3 = _mm256_fmadd_ps(diff3, diff3, sum3);
    }

    // Combine the four partial accumulators pairwise.
    let sum01 = _mm256_add_ps(sum0, sum1);
    let sum23 = _mm256_add_ps(sum2, sum3);
    let sum = _mm256_add_ps(sum01, sum23);

    // Process leftover full 8-wide chunks that didn't fill a 32-wide block.
    let remaining_start = chunks * 32;
    let remaining_chunks = (len - remaining_start) / 8;
    let mut final_sum = sum;
    for i in 0..remaining_chunks {
        let idx = remaining_start + i * 8;
        let va = _mm256_loadu_ps(a.as_ptr().add(idx));
        let vb = _mm256_loadu_ps(b.as_ptr().add(idx));
        let diff = _mm256_sub_ps(va, vb);
        final_sum = _mm256_fmadd_ps(diff, diff, final_sum);
    }

    // Horizontal reduction of the 8 f32 lanes.
    let sum_arr: [f32; 8] = std::mem::transmute(final_sum);
    let mut total = sum_arr.iter().sum::<f32>();

    // Scalar tail for the last (len % 8) elements.
    let scalar_start = remaining_start + remaining_chunks * 8;
    for i in scalar_start..len {
        let diff = a[i] - b[i];
        total += diff * diff;
    }

    total.sqrt()
}
189
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
/// AVX-512 Euclidean distance: 16 f32 lanes per iteration, then a scalar tail.
///
/// # Safety
/// Caller must ensure the CPU supports AVX-512F (e.g. via
/// `is_x86_feature_detected!("avx512f")`).
///
/// # Panics
/// Panics if `a.len() != b.len()`.
unsafe fn euclidean_distance_avx512_impl(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let mut sum = _mm512_setzero_ps();

    let chunks = len / 16;
    for i in 0..chunks {
        let idx = i * 16;
        let va = _mm512_loadu_ps(a.as_ptr().add(idx));
        let vb = _mm512_loadu_ps(b.as_ptr().add(idx));
        let diff = _mm512_sub_ps(va, vb);
        // Fused multiply-add: sum += diff * diff.
        sum = _mm512_fmadd_ps(diff, diff, sum);
    }

    // Hardware-assisted horizontal sum of the 16 lanes.
    let mut total = _mm512_reduce_add_ps(sum);

    // Scalar tail for the remaining (len % 16) elements.
    for i in (chunks * 16)..len {
        let diff = a[i] - b[i];
        total += diff * diff;
    }

    total.sqrt()
}
224
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
/// AVX-512 dot product: 16 f32 lanes per iteration, then a scalar tail.
///
/// # Safety
/// Caller must ensure the CPU supports AVX-512F.
///
/// # Panics
/// Panics if `a.len() != b.len()`.
unsafe fn dot_product_avx512_impl(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let mut sum = _mm512_setzero_ps();

    let chunks = len / 16;
    for i in 0..chunks {
        let idx = i * 16;
        let va = _mm512_loadu_ps(a.as_ptr().add(idx));
        let vb = _mm512_loadu_ps(b.as_ptr().add(idx));
        // sum += a * b, fused.
        sum = _mm512_fmadd_ps(va, vb, sum);
    }

    // Horizontal sum of the 16 lanes.
    let mut total = _mm512_reduce_add_ps(sum);

    // Scalar tail for the remaining (len % 16) elements.
    for i in (chunks * 16)..len {
        total += a[i] * b[i];
    }

    total
}
250
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
/// AVX-512 cosine similarity: accumulates the dot product and both squared
/// norms in a single pass, 16 f32 lanes at a time.
///
/// Returns NaN when either vector has zero magnitude (0/0 division).
///
/// # Safety
/// Caller must ensure the CPU supports AVX-512F.
///
/// # Panics
/// Panics if `a.len() != b.len()`.
unsafe fn cosine_similarity_avx512_impl(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    // Three running accumulators: a.b, a.a and b.b.
    let mut dot = _mm512_setzero_ps();
    let mut norm_a = _mm512_setzero_ps();
    let mut norm_b = _mm512_setzero_ps();

    let chunks = len / 16;
    for i in 0..chunks {
        let idx = i * 16;
        let va = _mm512_loadu_ps(a.as_ptr().add(idx));
        let vb = _mm512_loadu_ps(b.as_ptr().add(idx));

        dot = _mm512_fmadd_ps(va, vb, dot);
        norm_a = _mm512_fmadd_ps(va, va, norm_a);
        norm_b = _mm512_fmadd_ps(vb, vb, norm_b);
    }

    // Horizontal reductions of all three accumulators.
    let mut dot_sum = _mm512_reduce_add_ps(dot);
    let mut norm_a_sum = _mm512_reduce_add_ps(norm_a);
    let mut norm_b_sum = _mm512_reduce_add_ps(norm_b);

    // Scalar tail for the remaining (len % 16) elements.
    for i in (chunks * 16)..len {
        dot_sum += a[i] * b[i];
        norm_a_sum += a[i] * a[i];
        norm_b_sum += b[i] * b[i];
    }

    dot_sum / (norm_a_sum.sqrt() * norm_b_sum.sqrt())
}
285
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
/// AVX-512 Manhattan (L1) distance: 16 f32 lanes per iteration, then a scalar
/// tail.
///
/// # Safety
/// Caller must ensure the CPU supports AVX-512F.
///
/// # Panics
/// Panics if `a.len() != b.len()`.
unsafe fn manhattan_distance_avx512_impl(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let mut sum = _mm512_setzero_ps();

    let chunks = len / 16;
    for i in 0..chunks {
        let idx = i * 16;
        let va = _mm512_loadu_ps(a.as_ptr().add(idx));
        let vb = _mm512_loadu_ps(b.as_ptr().add(idx));
        let diff = _mm512_sub_ps(va, vb);
        // Lane-wise |a - b|.
        let abs_diff = _mm512_abs_ps(diff);
        sum = _mm512_add_ps(sum, abs_diff);
    }

    // Horizontal sum of the 16 lanes.
    let mut total = _mm512_reduce_add_ps(sum);

    // Scalar tail for the remaining (len % 16) elements.
    for i in (chunks * 16)..len {
        total += (a[i] - b[i]).abs();
    }

    total
}
313
#[cfg(target_arch = "aarch64")]
#[inline(always)]
#[allow(dead_code)]
/// NEON Euclidean distance: 4 f32 lanes per iteration, then a scalar tail.
///
/// # Safety
/// Runs on any aarch64 CPU (NEON is baseline). The unchecked indexing is made
/// sound by the length assertion below.
///
/// # Panics
/// Panics if `a.len() != b.len()`.
unsafe fn euclidean_distance_neon_impl(a: &[f32], b: &[f32]) -> f32 {
    // Upgraded from debug_assert_eq!: the unchecked reads below would be UB in
    // release builds on mismatched lengths; this also matches the x86 kernels.
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let mut sum = vdupq_n_f32(0.0);

    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    let chunks = len / 4;
    let mut idx = 0usize;

    for _ in 0..chunks {
        // SAFETY: idx + 4 <= len, so the 4-lane loads stay in bounds.
        let va = vld1q_f32(a_ptr.add(idx));
        let vb = vld1q_f32(b_ptr.add(idx));

        let diff = vsubq_f32(va, vb);

        // Fused multiply-add: sum += diff * diff.
        sum = vfmaq_f32(sum, diff, diff);

        idx += 4;
    }

    // Horizontal reduction of the 4 lanes.
    let mut total = vaddvq_f32(sum);

    // Scalar tail for the remaining (len % 4) elements.
    // SAFETY: i < len and both slices have length len (asserted above).
    for i in (chunks * 4)..len {
        let diff = *a.get_unchecked(i) - *b.get_unchecked(i);
        total += diff * diff;
    }

    total.sqrt()
}
363
#[cfg(target_arch = "aarch64")]
#[inline(always)]
/// NEON dot product: 4 f32 lanes per iteration, then a scalar tail.
///
/// # Safety
/// Runs on any aarch64 CPU (NEON is baseline). The unchecked indexing is made
/// sound by the length assertion below.
///
/// # Panics
/// Panics if `a.len() != b.len()`.
unsafe fn dot_product_neon_impl(a: &[f32], b: &[f32]) -> f32 {
    // Upgraded from debug_assert_eq!: the unchecked reads below would be UB in
    // release builds on mismatched lengths; this also matches the x86 kernels.
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let mut sum = vdupq_n_f32(0.0);

    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    let chunks = len / 4;
    let mut idx = 0usize;

    for _ in 0..chunks {
        // SAFETY: idx + 4 <= len, so the 4-lane loads stay in bounds.
        let va = vld1q_f32(a_ptr.add(idx));
        let vb = vld1q_f32(b_ptr.add(idx));

        // Fused multiply-add: sum += a * b.
        sum = vfmaq_f32(sum, va, vb);

        idx += 4;
    }

    // Horizontal reduction of the 4 lanes.
    let mut total = vaddvq_f32(sum);

    // Scalar tail for the remaining (len % 4) elements.
    // SAFETY: i < len and both slices have length len (asserted above).
    for i in (chunks * 4)..len {
        total += *a.get_unchecked(i) * *b.get_unchecked(i);
    }

    total
}
401
#[cfg(target_arch = "aarch64")]
#[inline(always)]
/// NEON cosine similarity: single pass accumulating the dot product and both
/// squared norms, 4 f32 lanes at a time.
///
/// Returns NaN when either vector has zero magnitude (0/0 division).
///
/// # Safety
/// Runs on any aarch64 CPU (NEON is baseline). The unchecked indexing is made
/// sound by the length assertion below.
///
/// # Panics
/// Panics if `a.len() != b.len()`.
unsafe fn cosine_similarity_neon_impl(a: &[f32], b: &[f32]) -> f32 {
    // Upgraded from debug_assert_eq!: the unchecked reads below would be UB in
    // release builds on mismatched lengths; this also matches the x86 kernels.
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    // Three running accumulators: a.b, a.a and b.b.
    let mut dot = vdupq_n_f32(0.0);
    let mut norm_a = vdupq_n_f32(0.0);
    let mut norm_b = vdupq_n_f32(0.0);

    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    let chunks = len / 4;
    let mut idx = 0usize;

    for _ in 0..chunks {
        // SAFETY: idx + 4 <= len, so the 4-lane loads stay in bounds.
        let va = vld1q_f32(a_ptr.add(idx));
        let vb = vld1q_f32(b_ptr.add(idx));

        dot = vfmaq_f32(dot, va, vb);

        norm_a = vfmaq_f32(norm_a, va, va);
        norm_b = vfmaq_f32(norm_b, vb, vb);

        idx += 4;
    }

    // Horizontal reductions of all three accumulators.
    let mut dot_sum = vaddvq_f32(dot);
    let mut norm_a_sum = vaddvq_f32(norm_a);
    let mut norm_b_sum = vaddvq_f32(norm_b);

    // Scalar tail for the remaining (len % 4) elements.
    // SAFETY: i < len and both slices have length len (asserted above).
    for i in (chunks * 4)..len {
        let ai = *a.get_unchecked(i);
        let bi = *b.get_unchecked(i);
        dot_sum += ai * bi;
        norm_a_sum += ai * ai;
        norm_b_sum += bi * bi;
    }

    dot_sum / (norm_a_sum.sqrt() * norm_b_sum.sqrt())
}
451
#[cfg(target_arch = "aarch64")]
#[inline(always)]
/// NEON Manhattan (L1) distance: 4 f32 lanes per iteration, then a scalar tail.
///
/// # Safety
/// Runs on any aarch64 CPU (NEON is baseline). The unchecked indexing is made
/// sound by the length assertion below.
///
/// # Panics
/// Panics if `a.len() != b.len()`.
unsafe fn manhattan_distance_neon_impl(a: &[f32], b: &[f32]) -> f32 {
    // Upgraded from debug_assert_eq!: the unchecked reads below would be UB in
    // release builds on mismatched lengths; this also matches the x86 kernels.
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let mut sum = vdupq_n_f32(0.0);

    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    let chunks = len / 4;
    let mut idx = 0usize;

    for _ in 0..chunks {
        // SAFETY: idx + 4 <= len, so the 4-lane loads stay in bounds.
        let va = vld1q_f32(a_ptr.add(idx));
        let vb = vld1q_f32(b_ptr.add(idx));

        // vabdq computes |a - b| per lane in one instruction.
        let abs_diff = vabdq_f32(va, vb);
        sum = vaddq_f32(sum, abs_diff);

        idx += 4;
    }

    // Horizontal reduction of the 4 lanes.
    let mut total = vaddvq_f32(sum);

    // Scalar tail for the remaining (len % 4) elements.
    // SAFETY: i < len and both slices have length len (asserted above).
    for i in (chunks * 4)..len {
        total += (*a.get_unchecked(i) - *b.get_unchecked(i)).abs();
    }

    total
}
490
#[cfg(target_arch = "aarch64")]
#[inline(always)]
/// NEON Euclidean distance, 4x unrolled: 16 f32s per iteration across four
/// independent accumulators, then leftover 4-wide chunks, then a scalar tail.
///
/// # Safety
/// Runs on any aarch64 CPU (NEON is baseline). The unchecked indexing is made
/// sound by the length assertion below.
///
/// # Panics
/// Panics if `a.len() != b.len()`.
unsafe fn euclidean_distance_neon_unrolled_impl(a: &[f32], b: &[f32]) -> f32 {
    // Upgraded from debug_assert_eq!: the unchecked reads below would be UB in
    // release builds on mismatched lengths; this also matches the x86 kernels.
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    // Four independent accumulators so consecutive FMAs do not serialize on a
    // single register dependency chain.
    let mut sum0 = vdupq_n_f32(0.0);
    let mut sum1 = vdupq_n_f32(0.0);
    let mut sum2 = vdupq_n_f32(0.0);
    let mut sum3 = vdupq_n_f32(0.0);

    let chunks = len / 16;
    let mut idx = 0usize;

    for _ in 0..chunks {
        // SAFETY: idx + 16 <= len, so all four 4-lane loads stay in bounds.
        let va0 = vld1q_f32(a_ptr.add(idx));
        let vb0 = vld1q_f32(b_ptr.add(idx));
        let diff0 = vsubq_f32(va0, vb0);
        sum0 = vfmaq_f32(sum0, diff0, diff0);

        let va1 = vld1q_f32(a_ptr.add(idx + 4));
        let vb1 = vld1q_f32(b_ptr.add(idx + 4));
        let diff1 = vsubq_f32(va1, vb1);
        sum1 = vfmaq_f32(sum1, diff1, diff1);

        let va2 = vld1q_f32(a_ptr.add(idx + 8));
        let vb2 = vld1q_f32(b_ptr.add(idx + 8));
        let diff2 = vsubq_f32(va2, vb2);
        sum2 = vfmaq_f32(sum2, diff2, diff2);

        let va3 = vld1q_f32(a_ptr.add(idx + 12));
        let vb3 = vld1q_f32(b_ptr.add(idx + 12));
        let diff3 = vsubq_f32(va3, vb3);
        sum3 = vfmaq_f32(sum3, diff3, diff3);

        idx += 16;
    }

    // Combine the four partial accumulators pairwise.
    let sum01 = vaddq_f32(sum0, sum1);
    let sum23 = vaddq_f32(sum2, sum3);
    let sum = vaddq_f32(sum01, sum23);

    // Process leftover full 4-wide chunks that didn't fill a 16-wide block.
    let remaining_start = chunks * 16;
    let remaining_chunks = (len - remaining_start) / 4;
    let mut final_sum = sum;

    idx = remaining_start;
    for _ in 0..remaining_chunks {
        // SAFETY: idx + 4 <= len for every iteration of this loop.
        let va = vld1q_f32(a_ptr.add(idx));
        let vb = vld1q_f32(b_ptr.add(idx));
        let diff = vsubq_f32(va, vb);
        final_sum = vfmaq_f32(final_sum, diff, diff);
        idx += 4;
    }

    // Horizontal reduction of the 4 lanes.
    let mut total = vaddvq_f32(final_sum);

    // Scalar tail for the last (len % 4) elements.
    // SAFETY: i < len and both slices have length len (asserted above).
    let scalar_start = remaining_start + remaining_chunks * 4;
    for i in scalar_start..len {
        let diff = *a.get_unchecked(i) - *b.get_unchecked(i);
        total += diff * diff;
    }

    total.sqrt()
}
576
#[cfg(target_arch = "aarch64")]
#[inline(always)]
/// NEON dot product, 4x unrolled: 16 f32s per iteration across four
/// independent accumulators, then leftover 4-wide chunks, then a scalar tail.
///
/// # Safety
/// Runs on any aarch64 CPU (NEON is baseline). The unchecked indexing is made
/// sound by the length assertion below.
///
/// # Panics
/// Panics if `a.len() != b.len()`.
unsafe fn dot_product_neon_unrolled_impl(a: &[f32], b: &[f32]) -> f32 {
    // Upgraded from debug_assert_eq!: the unchecked reads below would be UB in
    // release builds on mismatched lengths; this also matches the x86 kernels.
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    // Four independent accumulators to break FMA dependency chains.
    let mut sum0 = vdupq_n_f32(0.0);
    let mut sum1 = vdupq_n_f32(0.0);
    let mut sum2 = vdupq_n_f32(0.0);
    let mut sum3 = vdupq_n_f32(0.0);

    let chunks = len / 16;
    let mut idx = 0usize;

    for _ in 0..chunks {
        // SAFETY: idx + 16 <= len, so all four 4-lane loads stay in bounds.
        let va0 = vld1q_f32(a_ptr.add(idx));
        let vb0 = vld1q_f32(b_ptr.add(idx));
        sum0 = vfmaq_f32(sum0, va0, vb0);

        let va1 = vld1q_f32(a_ptr.add(idx + 4));
        let vb1 = vld1q_f32(b_ptr.add(idx + 4));
        sum1 = vfmaq_f32(sum1, va1, vb1);

        let va2 = vld1q_f32(a_ptr.add(idx + 8));
        let vb2 = vld1q_f32(b_ptr.add(idx + 8));
        sum2 = vfmaq_f32(sum2, va2, vb2);

        let va3 = vld1q_f32(a_ptr.add(idx + 12));
        let vb3 = vld1q_f32(b_ptr.add(idx + 12));
        sum3 = vfmaq_f32(sum3, va3, vb3);

        idx += 16;
    }

    // Combine the four partial accumulators pairwise.
    let sum01 = vaddq_f32(sum0, sum1);
    let sum23 = vaddq_f32(sum2, sum3);
    let sum = vaddq_f32(sum01, sum23);

    // Process leftover full 4-wide chunks that didn't fill a 16-wide block.
    let remaining_start = chunks * 16;
    let remaining_chunks = (len - remaining_start) / 4;
    let mut final_sum = sum;

    idx = remaining_start;
    for _ in 0..remaining_chunks {
        // SAFETY: idx + 4 <= len for every iteration of this loop.
        let va = vld1q_f32(a_ptr.add(idx));
        let vb = vld1q_f32(b_ptr.add(idx));
        final_sum = vfmaq_f32(final_sum, va, vb);
        idx += 4;
    }

    // Horizontal reduction of the 4 lanes.
    let mut total = vaddvq_f32(final_sum);

    // Scalar tail for the last (len % 4) elements.
    // SAFETY: i < len and both slices have length len (asserted above).
    let scalar_start = remaining_start + remaining_chunks * 4;
    for i in scalar_start..len {
        total += *a.get_unchecked(i) * *b.get_unchecked(i);
    }

    total
}
645
#[cfg(target_arch = "aarch64")]
#[inline(always)]
/// NEON cosine similarity, 2x unrolled: 8 f32s per iteration with paired
/// accumulators for the dot product and both squared norms.
///
/// Returns NaN when either vector has zero magnitude (0/0 division).
///
/// # Safety
/// Runs on any aarch64 CPU (NEON is baseline). The unchecked indexing is made
/// sound by the length assertion below.
///
/// # Panics
/// Panics if `a.len() != b.len()`.
unsafe fn cosine_similarity_neon_unrolled_impl(a: &[f32], b: &[f32]) -> f32 {
    // Upgraded from debug_assert_eq!: the unchecked reads below would be UB in
    // release builds on mismatched lengths; this also matches the x86 kernels.
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    // Two accumulators per quantity to break FMA dependency chains.
    let mut dot0 = vdupq_n_f32(0.0);
    let mut dot1 = vdupq_n_f32(0.0);
    let mut norm_a0 = vdupq_n_f32(0.0);
    let mut norm_a1 = vdupq_n_f32(0.0);
    let mut norm_b0 = vdupq_n_f32(0.0);
    let mut norm_b1 = vdupq_n_f32(0.0);

    let chunks = len / 8;
    let mut idx = 0usize;

    for _ in 0..chunks {
        // SAFETY: idx + 8 <= len, so both 4-lane load pairs stay in bounds.
        let va0 = vld1q_f32(a_ptr.add(idx));
        let vb0 = vld1q_f32(b_ptr.add(idx));
        dot0 = vfmaq_f32(dot0, va0, vb0);
        norm_a0 = vfmaq_f32(norm_a0, va0, va0);
        norm_b0 = vfmaq_f32(norm_b0, vb0, vb0);

        let va1 = vld1q_f32(a_ptr.add(idx + 4));
        let vb1 = vld1q_f32(b_ptr.add(idx + 4));
        dot1 = vfmaq_f32(dot1, va1, vb1);
        norm_a1 = vfmaq_f32(norm_a1, va1, va1);
        norm_b1 = vfmaq_f32(norm_b1, vb1, vb1);

        idx += 8;
    }

    // Merge the accumulator pairs, then reduce each horizontally.
    let dot = vaddq_f32(dot0, dot1);
    let norm_a = vaddq_f32(norm_a0, norm_a1);
    let norm_b = vaddq_f32(norm_b0, norm_b1);

    let mut dot_sum = vaddvq_f32(dot);
    let mut norm_a_sum = vaddvq_f32(norm_a);
    let mut norm_b_sum = vaddvq_f32(norm_b);

    // Scalar tail for the remaining (len % 8) elements.
    // SAFETY: i < len and both slices have length len (asserted above).
    for i in (chunks * 8)..len {
        let ai = *a.get_unchecked(i);
        let bi = *b.get_unchecked(i);
        dot_sum += ai * bi;
        norm_a_sum += ai * ai;
        norm_b_sum += bi * bi;
    }

    dot_sum / (norm_a_sum.sqrt() * norm_b_sum.sqrt())
}
705
#[cfg(target_arch = "aarch64")]
#[inline(always)]
/// NEON Manhattan (L1) distance, 4x unrolled: 16 f32s per iteration across
/// four independent accumulators, then leftover 4-wide chunks, then a scalar
/// tail.
///
/// # Safety
/// Runs on any aarch64 CPU (NEON is baseline). The unchecked indexing is made
/// sound by the length assertion below.
///
/// # Panics
/// Panics if `a.len() != b.len()`.
unsafe fn manhattan_distance_neon_unrolled_impl(a: &[f32], b: &[f32]) -> f32 {
    // Upgraded from debug_assert_eq!: the unchecked reads below would be UB in
    // release builds on mismatched lengths; this also matches the x86 kernels.
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    // Four independent accumulators to break dependency chains.
    let mut sum0 = vdupq_n_f32(0.0);
    let mut sum1 = vdupq_n_f32(0.0);
    let mut sum2 = vdupq_n_f32(0.0);
    let mut sum3 = vdupq_n_f32(0.0);

    let chunks = len / 16;
    let mut idx = 0usize;

    for _ in 0..chunks {
        // SAFETY: idx + 16 <= len, so all four 4-lane loads stay in bounds.
        // vabdq computes |a - b| per lane in one instruction.
        let va0 = vld1q_f32(a_ptr.add(idx));
        let vb0 = vld1q_f32(b_ptr.add(idx));
        sum0 = vaddq_f32(sum0, vabdq_f32(va0, vb0));

        let va1 = vld1q_f32(a_ptr.add(idx + 4));
        let vb1 = vld1q_f32(b_ptr.add(idx + 4));
        sum1 = vaddq_f32(sum1, vabdq_f32(va1, vb1));

        let va2 = vld1q_f32(a_ptr.add(idx + 8));
        let vb2 = vld1q_f32(b_ptr.add(idx + 8));
        sum2 = vaddq_f32(sum2, vabdq_f32(va2, vb2));

        let va3 = vld1q_f32(a_ptr.add(idx + 12));
        let vb3 = vld1q_f32(b_ptr.add(idx + 12));
        sum3 = vaddq_f32(sum3, vabdq_f32(va3, vb3));

        idx += 16;
    }

    // Combine the four partial accumulators pairwise.
    let sum01 = vaddq_f32(sum0, sum1);
    let sum23 = vaddq_f32(sum2, sum3);
    let sum = vaddq_f32(sum01, sum23);

    // Process leftover full 4-wide chunks that didn't fill a 16-wide block.
    let remaining_start = chunks * 16;
    let remaining_chunks = (len - remaining_start) / 4;
    let mut final_sum = sum;

    idx = remaining_start;
    for _ in 0..remaining_chunks {
        // SAFETY: idx + 4 <= len for every iteration of this loop.
        let va = vld1q_f32(a_ptr.add(idx));
        let vb = vld1q_f32(b_ptr.add(idx));
        final_sum = vaddq_f32(final_sum, vabdq_f32(va, vb));
        idx += 4;
    }

    // Horizontal reduction of the 4 lanes.
    let mut total = vaddvq_f32(final_sum);

    // Scalar tail for the last (len % 4) elements.
    // SAFETY: i < len and both slices have length len (asserted above).
    let scalar_start = remaining_start + remaining_chunks * 4;
    for i in scalar_start..len {
        total += (*a.get_unchecked(i) - *b.get_unchecked(i)).abs();
    }

    total
}
775
#[inline(always)]
/// Dot product of `a` and `b`, dispatched at runtime to the fastest SIMD
/// kernel available on the current CPU.
///
/// x86_64: prefers AVX-512F, then AVX2, then the scalar fallback. aarch64:
/// NEON is baseline; inputs of 64+ elements use the unrolled kernel. All other
/// targets use the scalar fallback.
///
/// # Panics
/// The x86_64 SIMD kernels panic if `a.len() != b.len()`. NOTE(review): the
/// NEON kernels only debug-assert the lengths — confirm callers always pass
/// equal-length slices.
pub fn dot_product_simd(a: &[f32], b: &[f32]) -> f32 {
    #[cfg(target_arch = "x86_64")]
    {
        // SAFETY (each arm): the matching feature was just detected.
        if is_x86_feature_detected!("avx512f") {
            unsafe { dot_product_avx512_impl(a, b) }
        } else if is_x86_feature_detected!("avx2") {
            unsafe { dot_product_avx2_impl(a, b) }
        } else {
            dot_product_scalar(a, b)
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        // Heuristic threshold for switching to the 4x-unrolled kernel.
        if a.len() >= 64 {
            unsafe { dot_product_neon_unrolled_impl(a, b) }
        } else {
            unsafe { dot_product_neon_impl(a, b) }
        }
    }

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        dot_product_scalar(a, b)
    }
}
809
#[inline(always)]
/// Forwarding wrapper, presumably kept for backward compatibility. Despite the
/// name it does not force AVX2: it delegates to [`dot_product_simd`], which
/// selects the best available instruction set at runtime.
pub fn dot_product_avx2(a: &[f32], b: &[f32]) -> f32 {
    dot_product_simd(a, b)
}
815
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
/// AVX2 dot product: 8 f32 lanes per iteration, then a scalar tail.
///
/// # Safety
/// Caller must ensure the CPU supports AVX2.
///
/// # Panics
/// Panics if `a.len() != b.len()`.
unsafe fn dot_product_avx2_impl(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let mut sum = _mm256_setzero_ps();

    let chunks = len / 8;
    for i in 0..chunks {
        let idx = i * 8;
        let va = _mm256_loadu_ps(a.as_ptr().add(idx));
        let vb = _mm256_loadu_ps(b.as_ptr().add(idx));
        let prod = _mm256_mul_ps(va, vb);
        sum = _mm256_add_ps(sum, prod);
    }

    // Horizontal reduction: reinterpret the register as 8 f32s and sum them.
    let sum_arr: [f32; 8] = std::mem::transmute(sum);
    let mut total = sum_arr.iter().sum::<f32>();

    // Scalar tail for the remaining (len % 8) elements.
    for i in (chunks * 8)..len {
        total += a[i] * b[i];
    }

    total
}
843
#[inline(always)]
/// Cosine similarity of `a` and `b`, dispatched at runtime to the fastest SIMD
/// kernel available on the current CPU.
///
/// x86_64: prefers AVX-512F, then AVX2, then the scalar fallback. aarch64:
/// NEON is baseline; inputs of 64+ elements use the unrolled kernel. All other
/// targets use the scalar fallback.
///
/// # Panics
/// The x86_64 SIMD kernels panic if `a.len() != b.len()`. NOTE(review): the
/// NEON kernels only debug-assert the lengths — confirm callers always pass
/// equal-length slices.
pub fn cosine_similarity_simd(a: &[f32], b: &[f32]) -> f32 {
    #[cfg(target_arch = "x86_64")]
    {
        // SAFETY (each arm): the matching feature was just detected.
        if is_x86_feature_detected!("avx512f") {
            unsafe { cosine_similarity_avx512_impl(a, b) }
        } else if is_x86_feature_detected!("avx2") {
            unsafe { cosine_similarity_avx2_impl(a, b) }
        } else {
            cosine_similarity_scalar(a, b)
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        // Heuristic threshold for switching to the unrolled kernel.
        if a.len() >= 64 {
            unsafe { cosine_similarity_neon_unrolled_impl(a, b) }
        } else {
            unsafe { cosine_similarity_neon_impl(a, b) }
        }
    }

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        cosine_similarity_scalar(a, b)
    }
}
873
#[inline(always)]
/// Forwarding wrapper, presumably kept for backward compatibility. Despite the
/// name it does not force AVX2: it delegates to [`cosine_similarity_simd`],
/// which selects the best available instruction set at runtime.
pub fn cosine_similarity_avx2(a: &[f32], b: &[f32]) -> f32 {
    cosine_similarity_simd(a, b)
}
879
#[inline(always)]
/// Manhattan (L1) distance between `a` and `b`, dispatched at runtime.
///
/// x86_64: uses AVX-512F when available, otherwise the scalar fallback — note
/// there is no AVX2 Manhattan kernel in this file. aarch64: NEON is baseline;
/// inputs of 64+ elements use the unrolled kernel. All other targets use the
/// scalar fallback.
///
/// # Panics
/// The AVX-512 kernel panics if `a.len() != b.len()`. NOTE(review): the NEON
/// kernels only debug-assert the lengths — confirm callers always pass
/// equal-length slices.
pub fn manhattan_distance_simd(a: &[f32], b: &[f32]) -> f32 {
    #[cfg(target_arch = "x86_64")]
    {
        // SAFETY: avx512f was just detected.
        if is_x86_feature_detected!("avx512f") {
            unsafe { manhattan_distance_avx512_impl(a, b) }
        } else {
            manhattan_distance_scalar(a, b)
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        // Heuristic threshold for switching to the 4x-unrolled kernel.
        if a.len() >= 64 {
            unsafe { manhattan_distance_neon_unrolled_impl(a, b) }
        } else {
            unsafe { manhattan_distance_neon_impl(a, b) }
        }
    }

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        manhattan_distance_scalar(a, b)
    }
}
907
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
/// AVX2 cosine similarity: accumulates the dot product and both squared norms
/// in a single pass, 8 f32 lanes at a time.
///
/// Returns NaN when either vector has zero magnitude (0/0 division).
///
/// # Safety
/// Caller must ensure the CPU supports AVX2.
///
/// # Panics
/// Panics if `a.len() != b.len()`.
unsafe fn cosine_similarity_avx2_impl(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    // Three running accumulators: a.b, a.a and b.b.
    let mut dot = _mm256_setzero_ps();
    let mut norm_a = _mm256_setzero_ps();
    let mut norm_b = _mm256_setzero_ps();

    let chunks = len / 8;
    for i in 0..chunks {
        let idx = i * 8;
        let va = _mm256_loadu_ps(a.as_ptr().add(idx));
        let vb = _mm256_loadu_ps(b.as_ptr().add(idx));

        dot = _mm256_add_ps(dot, _mm256_mul_ps(va, vb));

        norm_a = _mm256_add_ps(norm_a, _mm256_mul_ps(va, va));
        norm_b = _mm256_add_ps(norm_b, _mm256_mul_ps(vb, vb));
    }

    // Horizontal reduction: reinterpret each register as 8 f32s and sum.
    let dot_arr: [f32; 8] = std::mem::transmute(dot);
    let norm_a_arr: [f32; 8] = std::mem::transmute(norm_a);
    let norm_b_arr: [f32; 8] = std::mem::transmute(norm_b);

    let mut dot_sum = dot_arr.iter().sum::<f32>();
    let mut norm_a_sum = norm_a_arr.iter().sum::<f32>();
    let mut norm_b_sum = norm_b_arr.iter().sum::<f32>();

    // Scalar tail for the remaining (len % 8) elements.
    for i in (chunks * 8)..len {
        dot_sum += a[i] * b[i];
        norm_a_sum += a[i] * a[i];
        norm_b_sum += b[i] * b[i];
    }

    dot_sum / (norm_a_sum.sqrt() * norm_b_sum.sqrt())
}
949
/// Scalar (portable) Euclidean distance fallback.
///
/// # Panics
/// Panics if `a.len() != b.len()`. Previously a mismatched pair was silently
/// truncated to the shorter length by `zip`, inconsistent with the SIMD
/// kernels, which assert.
fn euclidean_distance_scalar(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");
    a.iter()
        .zip(b.iter())
        .map(|(x, y)| {
            let diff = x - y;
            diff * diff
        })
        .sum::<f32>()
        .sqrt()
}
962
/// Scalar (portable) dot-product fallback.
///
/// # Panics
/// Panics if `a.len() != b.len()`. Previously a mismatched pair was silently
/// truncated to the shorter length by `zip`, inconsistent with the SIMD
/// kernels, which assert.
fn dot_product_scalar(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");
    a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}
966
/// Scalar (portable) cosine-similarity fallback.
///
/// Returns NaN when either vector has zero magnitude (0/0 division).
///
/// # Panics
/// Panics if `a.len() != b.len()`. Previously a mismatched pair was silently
/// truncated to the shorter length by `zip` for the dot product (while the
/// norms used the full slices), inconsistent with the SIMD kernels.
fn cosine_similarity_scalar(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");
    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    dot / (norm_a * norm_b)
}
973
/// Scalar (portable) Manhattan (L1) distance fallback.
///
/// # Panics
/// Panics if `a.len() != b.len()`. Previously a mismatched pair was silently
/// truncated to the shorter length by `zip`, inconsistent with the SIMD
/// kernels, which assert.
fn manhattan_distance_scalar(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");
    a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum()
}
977
#[inline(always)]
/// Integer dot product of two i8 vectors (e.g. quantized embeddings),
/// accumulated in i32, dispatched at runtime.
///
/// x86_64: AVX2 kernel when available, otherwise scalar. aarch64: always the
/// NEON kernel. Other targets: scalar.
///
/// # Panics
/// The AVX2 kernel panics if `a.len() != b.len()`. NOTE(review): the NEON
/// kernel only debug-asserts the lengths — confirm callers always pass
/// equal-length slices.
pub fn dot_product_i8(a: &[i8], b: &[i8]) -> i32 {
    #[cfg(target_arch = "x86_64")]
    {
        // SAFETY: avx2 was just detected.
        if is_x86_feature_detected!("avx2") {
            unsafe { dot_product_i8_avx2_impl(a, b) }
        } else {
            dot_product_i8_scalar(a, b)
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NEON is baseline on aarch64.
        unsafe { dot_product_i8_neon_impl(a, b) }
    }

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        dot_product_i8_scalar(a, b)
    }
}
1005
#[inline(always)]
/// Squared Euclidean distance between two i8 vectors, accumulated in i32,
/// dispatched at runtime. Returns the squared distance (no sqrt).
///
/// x86_64: AVX2 kernel when available, otherwise scalar. aarch64: always the
/// NEON kernel. Other targets: scalar.
///
/// # Panics
/// The AVX2 kernel panics if `a.len() != b.len()`. NOTE(review): the NEON
/// kernel only debug-asserts the lengths — confirm callers always pass
/// equal-length slices.
pub fn euclidean_distance_squared_i8(a: &[i8], b: &[i8]) -> i32 {
    #[cfg(target_arch = "x86_64")]
    {
        // SAFETY: avx2 was just detected.
        if is_x86_feature_detected!("avx2") {
            unsafe { euclidean_distance_squared_i8_avx2_impl(a, b) }
        } else {
            euclidean_distance_squared_i8_scalar(a, b)
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NEON is baseline on aarch64.
        unsafe { euclidean_distance_squared_i8_neon_impl(a, b) }
    }

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        euclidean_distance_squared_i8_scalar(a, b)
    }
}
1029
#[cfg(target_arch = "aarch64")]
#[inline(always)]
/// NEON i8 dot product: widens 8 i8 lanes to i16, multiplies into i32
/// accumulators, then a scalar tail.
///
/// NOTE(review): the i32 accumulator can overflow in principle for extremely
/// long worst-case inputs (|product| up to 16384 per element) — confirm input
/// lengths stay well below ~10^5 elements or widen to i64.
///
/// # Safety
/// Runs on any aarch64 CPU (NEON is baseline). The unchecked indexing is made
/// sound by the length assertion below.
///
/// # Panics
/// Panics if `a.len() != b.len()`.
unsafe fn dot_product_i8_neon_impl(a: &[i8], b: &[i8]) -> i32 {
    // Upgraded from debug_assert_eq!: the unchecked reads below would be UB in
    // release builds on mismatched lengths; this also matches the x86 kernels.
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    let mut sum = vdupq_n_s32(0);

    let chunks = len / 8;
    let mut idx = 0usize;

    for _ in 0..chunks {
        // SAFETY: idx + 8 <= len, so the 8-byte loads stay in bounds.
        let va = vld1_s8(a_ptr.add(idx));
        let vb = vld1_s8(b_ptr.add(idx));

        // Sign-extend i8 -> i16 so the multiply cannot overflow.
        let va_i16 = vmovl_s8(va);
        let vb_i16 = vmovl_s8(vb);

        // Widening multiplies: i16 * i16 -> i32 for low and high halves.
        let prod_lo = vmull_s16(vget_low_s16(va_i16), vget_low_s16(vb_i16));
        let prod_hi = vmull_s16(vget_high_s16(va_i16), vget_high_s16(vb_i16));

        sum = vaddq_s32(sum, prod_lo);
        sum = vaddq_s32(sum, prod_hi);

        idx += 8;
    }

    // Horizontal reduction of the 4 i32 lanes.
    let mut total = vaddvq_s32(sum);

    // Scalar tail for the remaining (len % 8) elements.
    // SAFETY: i < len and both slices have length len (asserted above).
    for i in (chunks * 8)..len {
        total += (*a.get_unchecked(i) as i32) * (*b.get_unchecked(i) as i32);
    }

    total
}
1079
#[cfg(target_arch = "aarch64")]
#[inline(always)]
/// NEON squared Euclidean distance for i8 vectors: widens 8 i8 lanes to i16,
/// squares the differences into i32 accumulators, then a scalar tail.
///
/// # Safety
/// Runs on any aarch64 CPU (NEON is baseline). The unchecked indexing is made
/// sound by the length assertion below.
///
/// # Panics
/// Panics if `a.len() != b.len()`.
unsafe fn euclidean_distance_squared_i8_neon_impl(a: &[i8], b: &[i8]) -> i32 {
    // Upgraded from debug_assert_eq!: the unchecked reads below would be UB in
    // release builds on mismatched lengths; this also matches the x86 kernels.
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    let mut sum = vdupq_n_s32(0);

    let chunks = len / 8;
    let mut idx = 0usize;

    for _ in 0..chunks {
        // SAFETY: idx + 8 <= len, so the 8-byte loads stay in bounds.
        let va = vld1_s8(a_ptr.add(idx));
        let vb = vld1_s8(b_ptr.add(idx));

        // Sign-extend i8 -> i16 so the subtraction cannot overflow.
        let va_i16 = vmovl_s8(va);
        let vb_i16 = vmovl_s8(vb);

        let diff = vsubq_s16(va_i16, vb_i16);

        // Widening multiplies: i16 * i16 -> i32 for low and high halves.
        let prod_lo = vmull_s16(vget_low_s16(diff), vget_low_s16(diff));
        let prod_hi = vmull_s16(vget_high_s16(diff), vget_high_s16(diff));

        sum = vaddq_s32(sum, prod_lo);
        sum = vaddq_s32(sum, prod_hi);

        idx += 8;
    }

    // Horizontal reduction of the 4 i32 lanes.
    let mut total = vaddvq_s32(sum);

    // Scalar tail for the remaining (len % 8) elements.
    // SAFETY: i < len and both slices have length len (asserted above).
    for i in (chunks * 8)..len {
        let diff = (*a.get_unchecked(i) as i32) - (*b.get_unchecked(i) as i32);
        total += diff * diff;
    }

    total
}
1130
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
/// AVX2 i8 dot product: loads 32 i8s per iteration, sign-extends to i16 and
/// uses `madd` (pairwise i16*i16 -> i32 sums), then a scalar tail.
///
/// NOTE(review): the i32 accumulator can overflow in principle for extremely
/// long worst-case inputs (|product| up to 16384 per element) — confirm input
/// lengths stay well below ~10^5 elements or widen to i64.
///
/// # Safety
/// Caller must ensure the CPU supports AVX2.
///
/// # Panics
/// Panics if `a.len() != b.len()`.
unsafe fn dot_product_i8_avx2_impl(a: &[i8], b: &[i8]) -> i32 {
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let mut sum = _mm256_setzero_si256();

    let chunks = len / 32;
    for i in 0..chunks {
        let idx = i * 32;
        let va = _mm256_loadu_si256(a.as_ptr().add(idx) as *const __m256i);
        let vb = _mm256_loadu_si256(b.as_ptr().add(idx) as *const __m256i);

        // Sign-extend each 16-byte half from i8 to i16.
        let va_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(va));
        let vb_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(vb));
        let va_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(va, 1));
        let vb_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vb, 1));

        // madd: multiplies i16 pairs and sums adjacent products into i32.
        let prod_lo = _mm256_madd_epi16(va_lo, vb_lo);
        let prod_hi = _mm256_madd_epi16(va_hi, vb_hi);

        sum = _mm256_add_epi32(sum, prod_lo);
        sum = _mm256_add_epi32(sum, prod_hi);
    }

    // Horizontal reduction: reinterpret as 8 i32 lanes and sum.
    let sum_arr: [i32; 8] = std::mem::transmute(sum);
    let mut total: i32 = sum_arr.iter().sum();

    // Scalar tail for the remaining (len % 32) elements.
    for i in (chunks * 32)..len {
        total += (a[i] as i32) * (b[i] as i32);
    }

    total
}
1171
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
/// AVX2 squared Euclidean distance for i8 vectors: loads 32 i8s per iteration,
/// sign-extends to i16, squares the differences via `madd`, then a scalar
/// tail. Returns the squared distance (no sqrt).
///
/// # Safety
/// Caller must ensure the CPU supports AVX2.
///
/// # Panics
/// Panics if `a.len() != b.len()`.
unsafe fn euclidean_distance_squared_i8_avx2_impl(a: &[i8], b: &[i8]) -> i32 {
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");

    let len = a.len();
    let mut sum = _mm256_setzero_si256();

    let chunks = len / 32;
    for i in 0..chunks {
        let idx = i * 32;
        let va = _mm256_loadu_si256(a.as_ptr().add(idx) as *const __m256i);
        let vb = _mm256_loadu_si256(b.as_ptr().add(idx) as *const __m256i);

        // Sign-extend each 16-byte half from i8 to i16 so subtraction is exact.
        let va_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(va));
        let vb_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(vb));
        let va_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(va, 1));
        let vb_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vb, 1));

        let diff_lo = _mm256_sub_epi16(va_lo, vb_lo);
        let diff_hi = _mm256_sub_epi16(va_hi, vb_hi);

        // madd(diff, diff): squares i16 pairs and sums adjacent results to i32.
        let sq_lo = _mm256_madd_epi16(diff_lo, diff_lo);
        let sq_hi = _mm256_madd_epi16(diff_hi, diff_hi);

        sum = _mm256_add_epi32(sum, sq_lo);
        sum = _mm256_add_epi32(sum, sq_hi);
    }

    // Horizontal reduction: reinterpret as 8 i32 lanes and sum.
    let sum_arr: [i32; 8] = std::mem::transmute(sum);
    let mut total: i32 = sum_arr.iter().sum();

    // Scalar tail for the remaining (len % 32) elements.
    for i in (chunks * 32)..len {
        let diff = (a[i] as i32) - (b[i] as i32);
        total += diff * diff;
    }

    total
}
1213
/// Scalar (portable) i8 dot-product fallback, accumulated in i32.
///
/// # Panics
/// Panics if `a.len() != b.len()`. Previously a mismatched pair was silently
/// truncated to the shorter length by `zip`, inconsistent with the SIMD
/// kernels, which assert.
fn dot_product_i8_scalar(a: &[i8], b: &[i8]) -> i32 {
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");
    a.iter()
        .zip(b.iter())
        .map(|(&x, &y)| (x as i32) * (y as i32))
        .sum()
}
1221
/// Scalar (portable) squared Euclidean distance fallback for i8 vectors,
/// accumulated in i32. Returns the squared distance (no sqrt).
///
/// # Panics
/// Panics if `a.len() != b.len()`. Previously a mismatched pair was silently
/// truncated to the shorter length by `zip`, inconsistent with the SIMD
/// kernels, which assert.
fn euclidean_distance_squared_i8_scalar(a: &[i8], b: &[i8]) -> i32 {
    assert_eq!(a.len(), b.len(), "Input arrays must have the same length");
    a.iter()
        .zip(b.iter())
        .map(|(&x, &y)| {
            let diff = (x as i32) - (y as i32);
            diff * diff
        })
        .sum()
}
1232
1233#[inline]
1241pub fn batch_dot_product(query: &[f32], vectors: &[&[f32]], results: &mut [f32]) {
1242 assert_eq!(
1243 vectors.len(),
1244 results.len(),
1245 "Output size must match vector count"
1246 );
1247
1248 const TILE_SIZE: usize = 16;
1250
1251 for (chunk_idx, chunk) in vectors.chunks(TILE_SIZE).enumerate() {
1252 let base_idx = chunk_idx * TILE_SIZE;
1253 for (i, vec) in chunk.iter().enumerate() {
1254 results[base_idx + i] = dot_product_simd(query, vec);
1255 }
1256 }
1257}
1258
1259#[inline]
1263pub fn batch_euclidean(query: &[f32], vectors: &[&[f32]], results: &mut [f32]) {
1264 assert_eq!(
1265 vectors.len(),
1266 results.len(),
1267 "Output size must match vector count"
1268 );
1269
1270 const TILE_SIZE: usize = 16;
1271
1272 for (chunk_idx, chunk) in vectors.chunks(TILE_SIZE).enumerate() {
1273 let base_idx = chunk_idx * TILE_SIZE;
1274 for (i, vec) in chunk.iter().enumerate() {
1275 results[base_idx + i] = euclidean_distance_simd(query, vec);
1276 }
1277 }
1278}
1279
1280#[inline]
1282pub fn batch_cosine_similarity(query: &[f32], vectors: &[&[f32]], results: &mut [f32]) {
1283 assert_eq!(
1284 vectors.len(),
1285 results.len(),
1286 "Output size must match vector count"
1287 );
1288
1289 const TILE_SIZE: usize = 16;
1290
1291 for (chunk_idx, chunk) in vectors.chunks(TILE_SIZE).enumerate() {
1292 let base_idx = chunk_idx * TILE_SIZE;
1293 for (i, vec) in chunk.iter().enumerate() {
1294 results[base_idx + i] = cosine_similarity_simd(query, vec);
1295 }
1296 }
1297}
1298
1299#[inline]
1301pub fn batch_dot_product_owned(query: &[f32], vectors: &[Vec<f32>]) -> Vec<f32> {
1302 let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();
1303 let mut results = vec![0.0; vectors.len()];
1304 batch_dot_product(query, &refs, &mut results);
1305 results
1306}
1307
1308#[inline]
1310pub fn batch_euclidean_owned(query: &[f32], vectors: &[Vec<f32>]) -> Vec<f32> {
1311 let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();
1312 let mut results = vec![0.0; vectors.len()];
1313 batch_euclidean(query, &refs, &mut results);
1314 results
1315}
1316
#[cfg(test)]
mod tests {
    use super::*;

    // The SIMD tests below validate against the scalar reference
    // implementations. Float tolerances are loose enough to absorb the
    // different accumulation orders used by the vectorized kernels.

    #[test]
    fn test_euclidean_distance_simd() {
        // Exactly one 8-lane chunk: exercises the vector path with no tail.
        let a = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let b = [2.0f32, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];

        let simd = euclidean_distance_simd(&a, &b);
        let scalar = euclidean_distance_scalar(&a, &b);

        assert!(
            (simd - scalar).abs() < 0.001,
            "SIMD result {} differs from scalar result {}",
            simd,
            scalar
        );
    }

    #[test]
    fn test_euclidean_distance_large() {
        // 128 elements: long enough to hit the unrolled/wide code paths.
        let a: Vec<f32> = (0..128).map(|i| i as f32 * 0.1).collect();
        let b: Vec<f32> = (0..128).map(|i| (i as f32 * 0.1) + 0.5).collect();

        let simd = euclidean_distance_simd(&a, &b);
        let scalar = euclidean_distance_scalar(&a, &b);

        assert!(
            (simd - scalar).abs() < 0.01,
            "Large vector: SIMD {} vs scalar {}",
            simd,
            scalar
        );
    }

    #[test]
    fn test_dot_product_simd() {
        // Sixteen ones dotted with sixteen twos -> 32.
        let a = [1.0f32; 16];
        let b = [2.0f32; 16];

        let simd = dot_product_simd(&a, &b);
        assert!((simd - 32.0).abs() < 0.001);
    }

    #[test]
    fn test_dot_product_large() {
        let a: Vec<f32> = (0..256).map(|i| (i % 10) as f32).collect();
        let b: Vec<f32> = (0..256).map(|i| ((i + 5) % 10) as f32).collect();

        let simd = dot_product_simd(&a, &b);
        let scalar = dot_product_scalar(&a, &b);

        assert!(
            (simd - scalar).abs() < 0.1,
            "Large dot product: SIMD {} vs scalar {}",
            simd,
            scalar
        );
    }

    #[test]
    fn test_cosine_similarity_simd() {
        // Identical vectors have similarity 1.
        let a = [1.0f32, 0.0, 0.0];
        let b = [1.0f32, 0.0, 0.0];

        let sim = cosine_similarity_simd(&a, &b);
        assert!((sim - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = [1.0f32, 0.0, 0.0, 0.0];
        let b = [0.0f32, 1.0, 0.0, 0.0];

        let sim = cosine_similarity_simd(&a, &b);
        assert!(
            sim.abs() < 0.001,
            "Orthogonal vectors should have ~0 similarity, got {}",
            sim
        );
    }

    #[test]
    fn test_manhattan_distance_simd() {
        // |1-5| + |2-6| + |3-7| + |4-8| = 16.
        let a = [1.0f32, 2.0, 3.0, 4.0];
        let b = [5.0f32, 6.0, 7.0, 8.0];

        let simd = manhattan_distance_simd(&a, &b);
        let scalar = manhattan_distance_scalar(&a, &b);

        assert!(
            (simd - scalar).abs() < 0.001,
            "Manhattan: SIMD {} vs scalar {}",
            simd,
            scalar
        );
        assert!((simd - 16.0).abs() < 0.001);
    }

    #[test]
    fn test_non_aligned_lengths() {
        // Seven elements forces the scalar remainder loop after SIMD chunks.
        let a = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];
        let b = [2.0f32, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];

        let simd = euclidean_distance_simd(&a, &b);
        let scalar = euclidean_distance_scalar(&a, &b);

        assert!(
            (simd - scalar).abs() < 0.001,
            "Non-aligned: SIMD {} vs scalar {}",
            simd,
            scalar
        );
    }

    #[test]
    fn test_legacy_avx2_aliases() {
        // The *_avx2 names are kept for API compatibility; just confirm they
        // are callable and dispatch without panicking.
        let a = [1.0f32, 2.0, 3.0, 4.0];
        let b = [5.0f32, 6.0, 7.0, 8.0];

        let _ = euclidean_distance_avx2(&a, &b);
        let _ = dot_product_avx2(&a, &b);
        let _ = cosine_similarity_avx2(&a, &b);
    }

    #[test]
    fn test_dot_product_i8() {
        // Sixteen consecutive bytes, offset by one.
        let a: Vec<i8> = (1..=16).collect();
        let b: Vec<i8> = (2..=17).collect();

        let simd = dot_product_i8(&a, &b);
        let scalar = dot_product_i8_scalar(&a, &b);

        assert_eq!(
            simd, scalar,
            "INT8 dot product: SIMD {} vs scalar {}",
            simd, scalar
        );
    }

    #[test]
    fn test_dot_product_i8_large() {
        // 128 signed bytes shifted into [-64, 63] to exercise negatives.
        let a: Vec<i8> = (0..128)
            .map(|i| ((i % 256) as i8).wrapping_sub(64))
            .collect();
        let b: Vec<i8> = (0..128)
            .map(|i| (((i + 10) % 256) as i8).wrapping_sub(64))
            .collect();

        let simd = dot_product_i8(&a, &b);
        let scalar = dot_product_i8_scalar(&a, &b);

        assert_eq!(
            simd, scalar,
            "Large INT8 dot product: SIMD {} vs scalar {}",
            simd, scalar
        );
    }

    #[test]
    fn test_euclidean_distance_squared_i8() {
        // Each of the 16 components differs by 1, so distance^2 == 16.
        let a: Vec<i8> = (1..=16).collect();
        let b: Vec<i8> = (2..=17).collect();

        let simd = euclidean_distance_squared_i8(&a, &b);
        let scalar = euclidean_distance_squared_i8_scalar(&a, &b);

        assert_eq!(
            simd, scalar,
            "INT8 euclidean^2: SIMD {} vs scalar {}",
            simd, scalar
        );
        assert_eq!(simd, 16, "Expected 16, got {}", simd);
    }

    #[test]
    fn test_euclidean_distance_squared_i8_large() {
        let a: Vec<i8> = (0..128)
            .map(|i| ((i % 256) as i8).wrapping_sub(64))
            .collect();
        let b: Vec<i8> = (0..128)
            .map(|i| (((i + 5) % 256) as i8).wrapping_sub(64))
            .collect();

        let simd = euclidean_distance_squared_i8(&a, &b);
        let scalar = euclidean_distance_squared_i8_scalar(&a, &b);

        assert_eq!(
            simd, scalar,
            "Large INT8 euclidean^2: SIMD {} vs scalar {}",
            simd, scalar
        );
    }

    #[test]
    fn test_batch_dot_product() {
        // Dotting against standard basis vectors picks out single components.
        let query = vec![1.0, 2.0, 3.0, 4.0];
        let v1 = vec![1.0, 0.0, 0.0, 0.0];
        let v2 = vec![0.0, 1.0, 0.0, 0.0];
        let v3 = vec![0.0, 0.0, 1.0, 0.0];
        let vectors: Vec<&[f32]> = vec![&v1, &v2, &v3];
        let mut results = vec![0.0; vectors.len()];

        batch_dot_product(&query, &vectors, &mut results);

        assert!((results[0] - 1.0).abs() < 0.001);
        assert!((results[1] - 2.0).abs() < 0.001);
        assert!((results[2] - 3.0).abs() < 0.001);
    }

    #[test]
    fn test_batch_euclidean() {
        // 3-4-5 and 5-12-13 Pythagorean triples measured from the origin.
        let query = vec![0.0, 0.0, 0.0, 0.0];
        let v1 = vec![3.0, 4.0, 0.0, 0.0];
        let v2 = vec![0.0, 0.0, 5.0, 12.0];
        let vectors: Vec<&[f32]> = vec![&v1, &v2];
        let mut results = vec![0.0; vectors.len()];

        batch_euclidean(&query, &vectors, &mut results);

        assert!(
            (results[0] - 5.0).abs() < 0.001,
            "Expected 5.0, got {}",
            results[0]
        );
        assert!(
            (results[1] - 13.0).abs() < 0.001,
            "Expected 13.0, got {}",
            results[1]
        );
    }

    #[test]
    fn test_batch_cosine_similarity() {
        let query = vec![1.0, 0.0, 0.0, 0.0];
        let v1 = vec![1.0, 0.0, 0.0, 0.0]; // same direction
        let v2 = vec![0.0, 1.0, 0.0, 0.0]; // orthogonal
        let v3 = vec![-1.0, 0.0, 0.0, 0.0]; // opposite direction
        let vectors: Vec<&[f32]> = vec![&v1, &v2, &v3];
        let mut results = vec![0.0; vectors.len()];

        batch_cosine_similarity(&query, &vectors, &mut results);

        assert!(
            (results[0] - 1.0).abs() < 0.001,
            "Same direction should be 1.0"
        );
        assert!(results[1].abs() < 0.001, "Orthogonal should be 0.0");
        assert!((results[2] + 1.0).abs() < 0.001, "Opposite should be -1.0");
    }

    #[test]
    fn test_batch_owned_convenience() {
        let query = vec![1.0, 2.0, 3.0, 4.0];
        let vectors = vec![vec![1.0, 0.0, 0.0, 0.0], vec![0.0, 1.0, 0.0, 0.0]];

        let results = batch_dot_product_owned(&query, &vectors);
        assert_eq!(results.len(), 2);
        assert!((results[0] - 1.0).abs() < 0.001);
        assert!((results[1] - 2.0).abs() < 0.001);
    }

    #[test]
    fn test_unrolled_vs_non_unrolled_consistency() {
        // 128 elements sits above the length threshold that selects the
        // unrolled kernel on aarch64; the result must still match scalar.
        let a: Vec<f32> = (0..128).map(|i| i as f32 * 0.1).collect();
        let b: Vec<f32> = (0..128).map(|i| (i as f32 * 0.1) + 0.5).collect();

        let simd = euclidean_distance_simd(&a, &b);
        let scalar = euclidean_distance_scalar(&a, &b);

        assert!(
            (simd - scalar).abs() < 0.01,
            "Unrolled consistency: SIMD {} vs scalar {}",
            simd,
            scalar
        );
    }
}