1use crate::simd_avx512;
15use crate::simd_explicit;
16
/// Cache-line size, in bytes, assumed by the prefetch heuristics in this module.
pub const L2_CACHE_LINE_BYTES: usize = 64;

/// Returns a prefetch distance for vectors of `dimension` `f32` components.
///
/// The distance is the number of [`L2_CACHE_LINE_BYTES`]-sized cache lines a
/// single vector occupies, clamped to the inclusive range `4..=16`.
///
/// The byte-size computation saturates, so arbitrarily large dimensions map to
/// the upper bound instead of overflowing (which would panic in debug builds
/// and could wrap around to the lower bound in release builds).
#[inline]
#[must_use]
pub const fn calculate_prefetch_distance(dimension: usize) -> usize {
    // Bounds of the returned distance, in cache lines.
    const MIN_DISTANCE: usize = 4;
    const MAX_DISTANCE: usize = 16;

    let vector_bytes = dimension.saturating_mul(std::mem::size_of::<f32>());
    let raw_distance = vector_bytes / L2_CACHE_LINE_BYTES;

    // `Ord::clamp` is not callable in a `const fn`, hence the explicit branches.
    if raw_distance < MIN_DISTANCE {
        MIN_DISTANCE
    } else if raw_distance > MAX_DISTANCE {
        MAX_DISTANCE
    } else {
        raw_distance
    }
}
52
/// Issues a cache prefetch hint for the start of `vector`.
///
/// On `x86_64` this emits a single `_mm_prefetch` with the `_MM_HINT_T0`
/// locality hint for the address of the first element, i.e. only the cache
/// line containing the beginning of the slice is requested. On every other
/// architecture the function does nothing.
///
/// The slice contents are never read or modified.
#[inline]
pub fn prefetch_vector(vector: &[f32]) {
    #[cfg(target_arch = "x86_64")]
    {
        use std::arch::x86_64::{_mm_prefetch, _MM_HINT_T0};

        let start = vector.as_ptr().cast::<i8>();
        // SAFETY: `start` is derived from a live slice borrow. The prefetch
        // instruction is only a hint to the memory subsystem: it performs no
        // architecturally visible load, so even the dangling (but non-null)
        // pointer of an empty slice is acceptable.
        unsafe { _mm_prefetch(start, _MM_HINT_T0) };
    }

    #[cfg(not(target_arch = "x86_64"))]
    {
        // No portable prefetch primitive is used here; consume the argument so
        // the signature is identical on every target.
        let _ = vector;
    }
}
106
107#[inline]
122#[must_use]
123pub fn cosine_similarity_fast(a: &[f32], b: &[f32]) -> f32 {
124 simd_avx512::cosine_similarity_auto(a, b)
126}
127
128#[inline]
138#[must_use]
139pub fn euclidean_distance_fast(a: &[f32], b: &[f32]) -> f32 {
140 simd_avx512::euclidean_auto(a, b)
142}
143
144#[inline]
150#[must_use]
151pub fn squared_l2_distance(a: &[f32], b: &[f32]) -> f32 {
152 simd_avx512::squared_l2_auto(a, b)
154}
155
156#[inline]
162pub fn normalize_inplace(v: &mut [f32]) {
163 simd_explicit::normalize_inplace_simd(v);
164}
165
/// Euclidean (L2) norm of `v`: the square root of the sum of squared components.
///
/// This is a plain scalar implementation accumulating in `f32`; an empty
/// slice yields zero.
#[inline]
#[must_use]
pub fn norm(v: &[f32]) -> f32 {
    let sum_of_squares: f32 = v.iter().map(|&x| x * x).sum();
    sum_of_squares.sqrt()
}
172
173#[inline]
183#[must_use]
184pub fn dot_product_fast(a: &[f32], b: &[f32]) -> f32 {
185 simd_avx512::dot_product_auto(a, b)
187}
188
189#[inline]
204#[must_use]
205pub fn cosine_similarity_normalized(a: &[f32], b: &[f32]) -> f32 {
206 simd_avx512::cosine_similarity_normalized(a, b)
207}
208
209#[must_use]
216pub fn batch_cosine_normalized(candidates: &[&[f32]], query: &[f32]) -> Vec<f32> {
217 simd_avx512::batch_cosine_normalized(candidates, query)
218}
219
220#[inline]
243#[must_use]
244pub fn hamming_distance_fast(a: &[f32], b: &[f32]) -> f32 {
245 crate::simd_explicit::hamming_distance_simd(a, b)
247}
248
249#[inline]
273#[must_use]
274pub fn jaccard_similarity_fast(a: &[f32], b: &[f32]) -> f32 {
275 crate::simd_explicit::jaccard_similarity_simd(a, b)
277}
278
#[cfg(test)]
// Test data is built by casting small loop indices and counters to `f32`;
// every value involved is far below 2^24, so the casts are exact.
#[allow(clippy::cast_precision_loss)]
mod tests {
    use super::*;

    /// Absolute tolerance used for floating-point comparisons.
    const EPSILON: f32 = 1e-5;

    /// Builds a deterministic vector: `sin(seed + 0.1 * i)` for `i` in `0..dim`.
    fn generate_test_vector(dim: usize, seed: f32) -> Vec<f32> {
        (0..dim).map(|i| (seed + i as f32 * 0.1).sin()).collect()
    }

    // ---------------------------------------------------------------------
    // Cosine similarity
    // ---------------------------------------------------------------------

    #[test]
    fn test_cosine_similarity_identical_vectors() {
        let v = vec![1.0, 2.0, 3.0, 4.0];
        let result = cosine_similarity_fast(&v, &v);
        assert!(
            (result - 1.0).abs() < EPSILON,
            "Identical vectors should have similarity 1.0"
        );
    }

    #[test]
    fn test_cosine_similarity_orthogonal_vectors() {
        let a = vec![1.0, 0.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0, 0.0];
        let result = cosine_similarity_fast(&a, &b);
        assert!(
            result.abs() < EPSILON,
            "Orthogonal vectors should have similarity 0.0"
        );
    }

    #[test]
    fn test_cosine_similarity_opposite_vectors() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b: Vec<f32> = a.iter().map(|x| -x).collect();
        let result = cosine_similarity_fast(&a, &b);
        assert!(
            (result + 1.0).abs() < EPSILON,
            "Opposite vectors should have similarity -1.0"
        );
    }

    #[test]
    fn test_cosine_similarity_zero_vector() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![0.0, 0.0, 0.0];
        let result = cosine_similarity_fast(&a, &b);
        assert!(result.abs() < EPSILON, "Zero vector should return 0.0");
    }

    // ---------------------------------------------------------------------
    // Euclidean distance
    // ---------------------------------------------------------------------

    #[test]
    fn test_euclidean_distance_identical_vectors() {
        let v = vec![1.0, 2.0, 3.0, 4.0];
        let result = euclidean_distance_fast(&v, &v);
        assert!(
            result.abs() < EPSILON,
            "Identical vectors should have distance 0.0"
        );
    }

    #[test]
    fn test_euclidean_distance_known_value() {
        let a = vec![0.0, 0.0, 0.0];
        let b = vec![3.0, 4.0, 0.0];
        let result = euclidean_distance_fast(&a, &b);
        assert!(
            (result - 5.0).abs() < EPSILON,
            "Expected distance 5.0 (3-4-5 triangle)"
        );
    }

    #[test]
    fn test_euclidean_distance_768d() {
        let a = generate_test_vector(768, 0.0);
        let b = generate_test_vector(768, 1.0);

        let result = euclidean_distance_fast(&a, &b);

        // Naive scalar reference.
        let expected: f32 = a
            .iter()
            .zip(&b)
            .map(|(x, y)| (x - y).powi(2))
            .sum::<f32>()
            .sqrt();

        assert!(
            (result - expected).abs() < EPSILON,
            "Should match naive implementation"
        );
    }

    // ---------------------------------------------------------------------
    // Dot product
    // ---------------------------------------------------------------------

    #[test]
    fn test_dot_product_fast_correctness() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];
        let result = dot_product_fast(&a, &b);
        let expected = 1.0 * 5.0 + 2.0 * 6.0 + 3.0 * 7.0 + 4.0 * 8.0;
        assert!((result - expected).abs() < EPSILON);
    }

    #[test]
    fn test_dot_product_fast_768d() {
        let a = generate_test_vector(768, 0.0);
        let b = generate_test_vector(768, 1.0);

        let result = dot_product_fast(&a, &b);
        let expected: f32 = a.iter().zip(&b).map(|(x, y)| x * y).sum();

        // Summation order may differ from the scalar reference, so compare
        // with a relative tolerance instead of `EPSILON`.
        let rel_error = (result - expected).abs() / expected.abs().max(1.0);
        assert!(rel_error < 1e-4, "Relative error too high: {rel_error}");
    }

    // ---------------------------------------------------------------------
    // In-place normalization
    // ---------------------------------------------------------------------

    #[test]
    fn test_normalize_inplace_unit_vector() {
        let mut v = vec![3.0, 4.0, 0.0];
        normalize_inplace(&mut v);

        let norm_after: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!(
            (norm_after - 1.0).abs() < EPSILON,
            "Normalized vector should have unit norm"
        );
        assert!((v[0] - 0.6).abs() < EPSILON, "Expected 3/5 = 0.6");
        assert!((v[1] - 0.8).abs() < EPSILON, "Expected 4/5 = 0.8");
    }

    #[test]
    fn test_normalize_inplace_zero_vector() {
        let mut v = vec![0.0, 0.0, 0.0];
        normalize_inplace(&mut v);
        // A zero vector has no direction; it must be left untouched.
        assert!(v.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_normalize_inplace_768d() {
        let mut v = generate_test_vector(768, 0.0);
        normalize_inplace(&mut v);

        let norm_after: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!(
            (norm_after - 1.0).abs() < EPSILON,
            "Should be unit vector after normalization"
        );
    }

    // ---------------------------------------------------------------------
    // Cross-checks and edge cases
    // ---------------------------------------------------------------------

    #[test]
    fn test_cosine_consistency_with_baseline() {
        let a = generate_test_vector(768, 0.0);
        let b = generate_test_vector(768, 1.0);

        // Scalar baseline: dot / (|a| * |b|).
        let dot: f32 = a.iter().zip(&b).map(|(x, y)| x * y).sum();
        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
        let baseline = dot / (norm_a * norm_b);

        let fast = cosine_similarity_fast(&a, &b);

        assert!(
            (fast - baseline).abs() < EPSILON,
            "Fast implementation should match baseline: {fast} vs {baseline}"
        );
    }

    #[test]
    fn test_odd_dimension_vectors() {
        // Five elements: not a multiple of any SIMD lane width.
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let b = vec![5.0, 4.0, 3.0, 2.0, 1.0];

        let dot = dot_product_fast(&a, &b);
        let expected = 1.0 * 5.0 + 2.0 * 4.0 + 3.0 * 3.0 + 4.0 * 2.0 + 5.0 * 1.0;
        assert!((dot - expected).abs() < EPSILON);

        let dist = euclidean_distance_fast(&a, &b);
        let expected_dist: f32 = a
            .iter()
            .zip(&b)
            .map(|(x, y)| (x - y).powi(2))
            .sum::<f32>()
            .sqrt();
        assert!((dist - expected_dist).abs() < EPSILON);
    }

    #[test]
    fn test_small_vectors() {
        // Single element.
        let a = vec![3.0];
        let b = vec![4.0];
        assert!((dot_product_fast(&a, &b) - 12.0).abs() < EPSILON);
        assert!((euclidean_distance_fast(&a, &b) - 1.0).abs() < EPSILON);

        // Two elements, orthogonal.
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert!((cosine_similarity_fast(&a, &b)).abs() < EPSILON);
    }

    #[test]
    #[should_panic(expected = "Vector dimensions must match")]
    fn test_dimension_mismatch_panics() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.0, 2.0];
        let _ = cosine_similarity_fast(&a, &b);
    }

    // ---------------------------------------------------------------------
    // Norm
    // ---------------------------------------------------------------------

    #[test]
    fn test_norm_zero_vector() {
        let v = vec![0.0, 0.0, 0.0];
        assert!(norm(&v).abs() < EPSILON);
    }

    #[test]
    fn test_norm_unit_vector() {
        let v = vec![1.0, 0.0, 0.0];
        assert!((norm(&v) - 1.0).abs() < EPSILON);
    }

    #[test]
    fn test_norm_known_value() {
        let v = vec![3.0, 4.0];
        assert!((norm(&v) - 5.0).abs() < EPSILON);
    }

    // ---------------------------------------------------------------------
    // Squared L2 distance
    // ---------------------------------------------------------------------

    #[test]
    fn test_squared_l2_identical() {
        let v = vec![1.0, 2.0, 3.0];
        assert!(squared_l2_distance(&v, &v).abs() < EPSILON);
    }

    #[test]
    fn test_squared_l2_known_value() {
        let a = vec![0.0, 0.0];
        let b = vec![3.0, 4.0];
        assert!((squared_l2_distance(&a, &b) - 25.0).abs() < EPSILON);
    }

    // ---------------------------------------------------------------------
    // Hamming distance
    // ---------------------------------------------------------------------

    #[test]
    fn test_hamming_identical() {
        let a = vec![1.0, 0.0, 1.0, 0.0];
        assert!(hamming_distance_fast(&a, &a).abs() < EPSILON);
    }

    #[test]
    fn test_hamming_all_different() {
        let a = vec![1.0, 0.0, 1.0, 0.0];
        let b = vec![0.0, 1.0, 0.0, 1.0];
        assert!((hamming_distance_fast(&a, &b) - 4.0).abs() < EPSILON);
    }

    #[test]
    fn test_hamming_partial() {
        let a = vec![1.0, 1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0, 1.0];
        assert!((hamming_distance_fast(&a, &b) - 2.0).abs() < EPSILON);
    }

    #[test]
    fn test_hamming_odd_dimension() {
        let a = vec![1.0, 0.0, 1.0, 0.0, 1.0];
        let b = vec![0.0, 0.0, 1.0, 1.0, 1.0];
        assert!((hamming_distance_fast(&a, &b) - 2.0).abs() < EPSILON);
    }

    // ---------------------------------------------------------------------
    // Jaccard similarity
    // ---------------------------------------------------------------------

    #[test]
    fn test_jaccard_identical() {
        let a = vec![1.0, 0.0, 1.0, 0.0];
        assert!((jaccard_similarity_fast(&a, &a) - 1.0).abs() < EPSILON);
    }

    #[test]
    fn test_jaccard_disjoint() {
        let a = vec![1.0, 0.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0, 0.0];
        assert!(jaccard_similarity_fast(&a, &b).abs() < EPSILON);
    }

    #[test]
    fn test_jaccard_half_overlap() {
        // Intersection has 1 element, union has 3.
        let a = vec![1.0, 1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 1.0, 0.0];
        assert!((jaccard_similarity_fast(&a, &b) - (1.0 / 3.0)).abs() < EPSILON);
    }

    #[test]
    fn test_jaccard_empty_sets() {
        // Two empty sets are treated as identical.
        let a = vec![0.0, 0.0, 0.0, 0.0];
        let b = vec![0.0, 0.0, 0.0, 0.0];
        assert!((jaccard_similarity_fast(&a, &b) - 1.0).abs() < EPSILON);
    }

    #[test]
    fn test_jaccard_simd_large_vectors() {
        let a: Vec<f32> = (0..768)
            .map(|i| if i % 2 == 0 { 1.0 } else { 0.0 })
            .collect();
        let b: Vec<f32> = (0..768)
            .map(|i| if i % 3 == 0 { 1.0 } else { 0.0 })
            .collect();

        let result = jaccard_similarity_fast(&a, &b);

        // Only the range is checked here; exact values are covered below.
        assert!((0.0..=1.0).contains(&result), "Jaccard must be in [0,1]");
    }

    #[test]
    fn test_jaccard_simd_aligned_vectors() {
        // 64 elements; a = first 32 set, b = first 48 set.
        let a: Vec<f32> = (0..64).map(|i| if i < 32 { 1.0 } else { 0.0 }).collect();
        let b: Vec<f32> = (0..64).map(|i| if i < 48 { 1.0 } else { 0.0 }).collect();

        let result = jaccard_similarity_fast(&a, &b);

        // Intersection = 32, union = 48.
        let expected = 32.0 / 48.0;
        assert!(
            (result - expected).abs() < EPSILON,
            "Expected {expected}, got {result}"
        );
    }

    #[test]
    fn test_jaccard_simd_unaligned_vectors() {
        // 67 elements; a = first 30 set, b = first 40 set.
        let a: Vec<f32> = (0..67).map(|i| if i < 30 { 1.0 } else { 0.0 }).collect();
        let b: Vec<f32> = (0..67).map(|i| if i < 40 { 1.0 } else { 0.0 }).collect();

        let result = jaccard_similarity_fast(&a, &b);

        // Intersection = 30, union = 40.
        let expected = 30.0 / 40.0;
        assert!(
            (result - expected).abs() < EPSILON,
            "Expected {expected}, got {result}"
        );
    }

    #[test]
    fn test_jaccard_consistency_scalar_vs_reference() {
        // Dimensions just below and exactly at typical lane multiples.
        for dim in [7, 8, 15, 16, 31, 32, 63, 64, 127, 128, 255, 256, 768] {
            let a: Vec<f32> = (0..dim)
                .map(|i| if (i * 7) % 11 < 6 { 1.0 } else { 0.0 })
                .collect();
            let b: Vec<f32> = (0..dim)
                .map(|i| if (i * 5) % 9 < 5 { 1.0 } else { 0.0 })
                .collect();

            let result = jaccard_similarity_fast(&a, &b);

            // Scalar reference: a component is "in the set" when it is > 0.5.
            let mut intersection = 0u32;
            let mut union = 0u32;
            for (&x, &y) in a.iter().zip(&b) {
                let in_a = x > 0.5;
                let in_b = y > 0.5;
                if in_a && in_b {
                    intersection += 1;
                }
                if in_a || in_b {
                    union += 1;
                }
            }
            let expected = if union == 0 {
                1.0
            } else {
                intersection as f32 / union as f32
            };

            assert!(
                (result - expected).abs() < EPSILON,
                "Dim {dim}: expected {expected}, got {result}"
            );
        }
    }

    // ---------------------------------------------------------------------
    // Prefetch helpers
    // ---------------------------------------------------------------------

    #[test]
    fn test_calculate_prefetch_distance_small_vectors() {
        // 32 * 4 / 64 = 2 and 64 * 4 / 64 = 4: both end up at the minimum.
        assert_eq!(calculate_prefetch_distance(32), 4);
        assert_eq!(calculate_prefetch_distance(64), 4);
    }

    #[test]
    fn test_calculate_prefetch_distance_medium_vectors() {
        // 128 * 4 / 64 = 8 and 256 * 4 / 64 = 16: inside the clamp range.
        assert_eq!(calculate_prefetch_distance(128), 8);
        assert_eq!(calculate_prefetch_distance(256), 16);
    }

    #[test]
    fn test_calculate_prefetch_distance_large_vectors() {
        // 48 and 96 cache lines respectively: clamped to the maximum.
        assert_eq!(calculate_prefetch_distance(768), 16);
        assert_eq!(calculate_prefetch_distance(1536), 16);
    }

    #[test]
    fn test_prefetch_vector_does_not_panic() {
        let empty: Vec<f32> = vec![];
        prefetch_vector(&empty);

        let small = vec![1.0, 2.0, 3.0];
        prefetch_vector(&small);

        let large = generate_test_vector(768, 0.0);
        prefetch_vector(&large);
    }

    #[test]
    fn test_l2_cache_line_constant() {
        assert_eq!(L2_CACHE_LINE_BYTES, 64);
    }
}