1use alloc::collections::BinaryHeap;
2use alloc::vec;
3use alloc::vec::Vec;
4use core::fmt::{self, Debug};
5
6use crate::vector::SQVec;
7
8#[cfg(all(target_arch = "x86_64", feature = "std"))]
9mod simd_x86;
10
11#[inline]
/// Square root of `x`, usable with or without the `std` feature.
///
/// With `std` enabled this is simply [`f32::sqrt`]. Without it, a software
/// fallback is used: a bit-level initial estimate refined by five
/// Newton-Raphson steps.
///
/// Fallback special cases: negative inputs and NaN yield NaN, zero and
/// infinity are returned unchanged, and subnormal inputs are rescaled into
/// the normal range before the estimate is taken.
fn sqrt_f32(x: f32) -> f32 {
    #[cfg(feature = "std")]
    {
        x.sqrt()
    }
    #[cfg(not(feature = "std"))]
    {
        if x.is_nan() || x < 0.0 {
            return f32::NAN;
        }
        if x.is_infinite() || x == 0.0 {
            return x;
        }
        if x < f32::MIN_POSITIVE {
            // Subnormal input: scale up by 2^24, take the root, then undo the
            // scaling with 2^12 (the square root of 2^24).
            return sqrt_f32(x * 16_777_216.0) / 4096.0;
        }
        // Shifting the raw bits right by one roughly halves the exponent;
        // adding `127 << 22` (0x1FC0_0000) restores the exponent bias.
        let seed = f32::from_bits((x.to_bits() >> 1) + 0x1FC0_0000);
        (0..5).fold(seed, |estimate, _| 0.5 * (estimate + x / estimate))
    }
}
50
/// Distance functions that can be selected at runtime.
///
/// [`DistanceMetric::compute`] evaluates the chosen metric on two slices.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum DistanceMetric {
    /// Cosine distance, i.e. `1 - cosine_similarity`.
    Cosine,
    /// Squared Euclidean distance (no final square root).
    EuclideanSq,
    /// Negated dot product, so that larger inner products give smaller values.
    DotProduct,
    /// Manhattan distance: the sum of absolute component differences.
    Manhattan,
}
87
88impl DistanceMetric {
89 #[inline]
103 pub fn compute(&self, a: &[f32], b: &[f32]) -> f32 {
104 if a.len() != b.len() {
105 return f32::MAX;
106 }
107 let d = match self {
108 Self::Cosine => cosine_distance(a, b),
109 Self::EuclideanSq => euclidean_distance_sq(a, b),
110 Self::DotProduct => -dot_product(a, b),
111 Self::Manhattan => manhattan_distance(a, b),
112 };
113 if d.is_nan() { f32::MAX } else { d }
114 }
115}
116
117impl fmt::Display for DistanceMetric {
118 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
119 match self {
120 Self::Cosine => f.write_str("cosine"),
121 Self::EuclideanSq => f.write_str("euclidean_sq"),
122 Self::DotProduct => f.write_str("dot_product"),
123 Self::Manhattan => f.write_str("manhattan"),
124 }
125 }
126}
127
128#[inline]
139pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
140 assert_eq!(a.len(), b.len(), "dot_product: dimension mismatch");
141 #[cfg(all(target_arch = "x86_64", feature = "std"))]
142 {
143 if is_x86_feature_detected!("avx2") {
144 return unsafe { simd_x86::dot_product_avx2(a, b) };
146 }
147 }
148 dot_product_scalar(a, b)
149}
150
/// Portable dot product: sums the pairwise products of `a` and `b`.
///
/// Pairs are formed with `zip`, so any surplus elements of the longer slice
/// are ignored.
#[inline]
fn dot_product_scalar(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum()
}
155
156#[inline]
166pub fn euclidean_distance_sq(a: &[f32], b: &[f32]) -> f32 {
167 assert_eq!(
168 a.len(),
169 b.len(),
170 "euclidean_distance_sq: dimension mismatch"
171 );
172 #[cfg(all(target_arch = "x86_64", feature = "std"))]
173 {
174 if is_x86_feature_detected!("avx2") {
175 return unsafe { simd_x86::euclidean_distance_sq_avx2(a, b) };
177 }
178 }
179 euclidean_distance_sq_scalar(a, b)
180}
181
/// Portable squared Euclidean distance: sums the squared pairwise differences.
///
/// Pairs are formed with `zip`, so any surplus elements of the longer slice
/// are ignored.
#[inline]
fn euclidean_distance_sq_scalar(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b)
        .map(|(&lhs, &rhs)| {
            let delta = lhs - rhs;
            delta * delta
        })
        .sum()
}
186
187#[inline]
198pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
199 assert_eq!(a.len(), b.len(), "cosine_similarity: dimension mismatch");
200 #[cfg(all(target_arch = "x86_64", feature = "std"))]
201 {
202 if is_x86_feature_detected!("avx2") {
203 return unsafe { simd_x86::cosine_similarity_avx2(a, b) };
205 }
206 }
207 cosine_similarity_scalar(a, b)
208}
209
210#[inline]
211fn cosine_similarity_scalar(a: &[f32], b: &[f32]) -> f32 {
212 let mut dot = 0.0f32;
213 let mut norm_a = 0.0f32;
214 let mut norm_b = 0.0f32;
215 for i in 0..a.len() {
216 let x = a[i];
217 let y = b[i];
218 dot += x * y;
219 norm_a += x * x;
220 norm_b += y * y;
221 }
222 let denom = sqrt_f32(norm_a) * sqrt_f32(norm_b);
223 if denom == 0.0 {
224 0.0
225 } else {
226 (dot / denom).clamp(-1.0, 1.0)
227 }
228}
229
230#[inline]
241pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
242 1.0 - cosine_similarity(a, b)
243}
244
245#[inline]
253pub fn manhattan_distance(a: &[f32], b: &[f32]) -> f32 {
254 assert_eq!(a.len(), b.len(), "manhattan_distance: dimension mismatch");
255 #[cfg(all(target_arch = "x86_64", feature = "std"))]
256 {
257 if is_x86_feature_detected!("avx2") {
258 return unsafe { simd_x86::manhattan_distance_avx2(a, b) };
260 }
261 }
262 manhattan_distance_scalar(a, b)
263}
264
/// Portable Manhattan distance: sums the absolute pairwise differences.
///
/// Pairs are formed with `zip`, so any surplus elements of the longer slice
/// are ignored.
#[inline]
fn manhattan_distance_scalar(a: &[f32], b: &[f32]) -> f32 {
    let gaps = a.iter().zip(b).map(|(&lhs, &rhs)| (lhs - rhs).abs());
    gaps.sum()
}
269
270#[inline]
277pub fn hamming_distance(a: &[u8], b: &[u8]) -> u32 {
278 let len = a.len().min(b.len());
279 let a = &a[..len];
280 let b = &b[..len];
281 #[cfg(all(target_arch = "x86_64", feature = "std"))]
282 {
283 if is_x86_feature_detected!("avx2") {
284 return unsafe { simd_x86::hamming_distance_avx2(a, b) };
286 }
287 }
288 hamming_distance_scalar(a, b)
289}
290
/// Portable Hamming distance: XORs paired bytes and counts the set bits.
///
/// Pairs are formed with `zip`, so any surplus bytes of the longer slice are
/// ignored.
#[inline]
fn hamming_distance_scalar(a: &[u8], b: &[u8]) -> u32 {
    let mut differing_bits = 0u32;
    for (&lhs, &rhs) in a.iter().zip(b) {
        differing_bits += (lhs ^ rhs).count_ones();
    }
    differing_bits
}
298
299#[inline]
307pub fn l2_norm(v: &[f32]) -> f32 {
308 sqrt_f32(v.iter().map(|x| x * x).sum::<f32>())
309}
310
311#[inline]
321pub fn l2_normalize(v: &mut [f32]) {
322 let norm = l2_norm(v);
323 if norm.is_finite() && norm > 0.0 {
324 let inv = 1.0 / norm;
325 for x in v.iter_mut() {
326 *x *= inv;
327 }
328 } else if !norm.is_finite() {
329 let max_abs = v.iter().fold(0.0f32, |acc, &x| {
332 let a = x.abs();
333 if a > acc { a } else { acc }
334 });
335 if max_abs == 0.0 || !max_abs.is_finite() {
336 return;
337 }
338 let inv_max = 1.0 / max_abs;
339 for x in v.iter_mut() {
340 *x *= inv_max;
341 }
342 let scaled_norm = l2_norm(v);
343 if scaled_norm.is_finite() && scaled_norm > 0.0 {
344 let inv = 1.0 / scaled_norm;
345 for x in v.iter_mut() {
346 *x *= inv;
347 }
348 }
349 }
350}
351
352#[inline]
356pub fn l2_normalized(v: &[f32]) -> Vec<f32> {
357 let mut out = v.to_vec();
358 l2_normalize(&mut out);
359 out
360}
361
/// Packs the sign pattern of `v` into bits, eight components per byte.
///
/// A component contributes a `1` bit only when it is strictly greater than
/// zero (so zero, negatives and NaN all map to `0`). Bits are stored
/// most-significant first: component `0` lands in bit 7 of byte `0`. The
/// result holds `ceil(v.len() / 8)` bytes, with unused trailing bits cleared.
pub fn quantize_binary(v: &[f32]) -> Vec<u8> {
    v.chunks(8)
        .map(|group| {
            group
                .iter()
                .enumerate()
                .fold(0u8, |byte, (pos, &component)| {
                    if component > 0.0 {
                        byte | (0x80 >> pos)
                    } else {
                        byte
                    }
                })
        })
        .collect()
}
395
396pub fn quantize_scalar<const N: usize>(v: &[f32; N]) -> SQVec<N> {
414 let mut min_val = f32::INFINITY;
415 let mut max_val = f32::NEG_INFINITY;
416 for &x in v {
417 if x < min_val {
418 min_val = x;
419 }
420 if x > max_val {
421 max_val = x;
422 }
423 }
424
425 let mut codes = [0u8; N];
426 let range = max_val - min_val;
427 if !range.is_finite() {
428 return SQVec {
431 min_val: 0.0,
432 max_val: 0.0,
433 codes,
434 };
435 }
436 if range >= f32::MIN_POSITIVE {
437 let inv_range = 255.0 / range;
438 if inv_range.is_finite() {
439 for (i, &x) in v.iter().enumerate() {
440 #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
443 let q = ((x - min_val) * inv_range + 0.5) as u8;
444 codes[i] = q;
445 }
446 }
447 }
449 SQVec {
452 min_val,
453 max_val,
454 codes,
455 }
456}
457
458#[inline]
462pub fn dequantize_scalar<const N: usize>(sq: &SQVec<N>) -> [f32; N] {
463 sq.dequantize()
464}
465
466#[inline]
475pub fn sq_euclidean_distance_sq<const N: usize>(query: &[f32; N], sq: &SQVec<N>) -> f32 {
476 let range = sq.max_val - sq.min_val;
477 if !range.is_finite() {
478 return f32::MAX;
479 }
480 if range == 0.0 {
481 let d: f32 = query
483 .iter()
484 .map(|&q| {
485 let diff = q - sq.min_val;
486 diff * diff
487 })
488 .sum();
489 return if d.is_nan() { f32::MAX } else { d };
490 }
491 let scale = range / 255.0;
492 let mut sum = 0.0f32;
493 for (i, &q) in query.iter().enumerate() {
494 let dequant = sq.min_val + f32::from(sq.codes[i]) * scale;
495 let diff = q - dequant;
496 sum += diff * diff;
497 }
498 if sum.is_nan() { f32::MAX } else { sum }
499}
500
501#[inline]
508pub fn sq_dot_product<const N: usize>(query: &[f32; N], sq: &SQVec<N>) -> f32 {
509 let range = sq.max_val - sq.min_val;
510 if !range.is_finite() {
511 return 0.0;
512 }
513 if range == 0.0 {
514 let d = query.iter().sum::<f32>() * sq.min_val;
515 return if d.is_nan() { 0.0 } else { d };
516 }
517 let scale = range / 255.0;
518 let mut sum = 0.0f32;
519 for (i, &q) in query.iter().enumerate() {
520 let dequant = sq.min_val + f32::from(sq.codes[i]) * scale;
521 sum += q * dequant;
522 }
523 if sum.is_nan() { 0.0 } else { sum }
524}
525
/// A search hit: a caller-supplied key paired with its distance to the query.
///
/// Equality and ordering look **only** at `distance` (using the IEEE 754
/// total order, so NaN and signed zeros are ordered deterministically); the
/// `key` never takes part in comparisons.
#[derive(Debug, Clone)]
pub struct Neighbor<K> {
    /// Identifier of the matched item.
    pub key: K,
    /// Distance from the query to the matched item.
    pub distance: f32,
}

impl<K> Ord for Neighbor<K> {
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        f32::total_cmp(&self.distance, &other.distance)
    }
}

impl<K> PartialOrd for Neighbor<K> {
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl<K> PartialEq for Neighbor<K> {
    fn eq(&self, other: &Self) -> bool {
        // `total_cmp` reports `Equal` exactly when the two bit patterns are
        // identical, which keeps `Eq` consistent with `Ord`.
        matches!(Ord::cmp(self, other), core::cmp::Ordering::Equal)
    }
}

impl<K> Eq for Neighbor<K> {}
562
563pub fn nearest_k<K, I, F>(iter: I, query: &[f32], k: usize, distance_fn: F) -> Vec<Neighbor<K>>
594where
595 I: Iterator<Item = (K, Vec<f32>)>,
596 F: Fn(&[f32], &[f32]) -> f32,
597{
598 if k == 0 {
599 return Vec::new();
600 }
601
602 let mut heap: BinaryHeap<Neighbor<K>> = BinaryHeap::with_capacity(k + 1);
605
606 for (key, vec) in iter {
607 let dist = distance_fn(query, &vec);
608 if heap.len() < k {
609 heap.push(Neighbor {
610 key,
611 distance: dist,
612 });
613 } else if heap
614 .peek()
615 .is_some_and(|worst| dist.total_cmp(&worst.distance).is_lt())
616 {
617 heap.pop();
618 heap.push(Neighbor {
619 key,
620 distance: dist,
621 });
622 }
623 }
624
625 let mut results: Vec<Neighbor<K>> = heap.into_vec();
626 results.sort_by(|a, b| a.distance.total_cmp(&b.distance));
627 results
628}
629
630pub fn nearest_k_fixed<K, I, F, const N: usize>(
635 iter: I,
636 query: &[f32; N],
637 k: usize,
638 distance_fn: F,
639) -> Vec<Neighbor<K>>
640where
641 I: Iterator<Item = (K, [f32; N])>,
642 F: Fn(&[f32], &[f32]) -> f32,
643{
644 if k == 0 {
645 return Vec::new();
646 }
647
648 let mut heap: BinaryHeap<Neighbor<K>> = BinaryHeap::with_capacity(k + 1);
649
650 for (key, vec) in iter {
651 let dist = distance_fn(query.as_slice(), vec.as_slice());
652 if heap.len() < k {
653 heap.push(Neighbor {
654 key,
655 distance: dist,
656 });
657 } else if heap
658 .peek()
659 .is_some_and(|worst| dist.total_cmp(&worst.distance).is_lt())
660 {
661 heap.pop();
662 heap.push(Neighbor {
663 key,
664 distance: dist,
665 });
666 }
667 }
668
669 let mut results: Vec<Neighbor<K>> = heap.into_vec();
670 results.sort_by(|a, b| a.distance.total_cmp(&b.distance));
671
672 results
673}
674
/// Serializes `values` into `dest` as consecutive little-endian `f32`s.
///
/// Writes `min(dest.len() / 4, values.len())` values. Any remaining bytes of
/// `dest` are left untouched and any surplus `values` are ignored; the
/// function never panics on a size mismatch.
#[inline]
pub fn write_f32_le(dest: &mut [u8], values: &[f32]) {
    let count = values.len().min(dest.len() / 4);
    #[cfg(target_endian = "little")]
    {
        // On little-endian targets the in-memory layout already matches the
        // wire format, so the values can be copied as raw bytes.
        let byte_len = count * 4;
        // SAFETY: `count <= values.len()`, so the first `byte_len` bytes
        // behind `values` are initialized `f32` storage; `u8` has no
        // alignment requirement, and the view is only used inside this block
        // while `values` is borrowed.
        let raw = unsafe { core::slice::from_raw_parts(values.as_ptr().cast::<u8>(), byte_len) };
        dest[..byte_len].copy_from_slice(raw);
    }
    #[cfg(not(target_endian = "little"))]
    {
        for (slot, val) in dest.chunks_exact_mut(4).zip(&values[..count]) {
            slot.copy_from_slice(&val.to_le_bytes());
        }
    }
}
710
/// Decodes consecutive little-endian `f32`s from `src`.
///
/// Only whole 4-byte groups are decoded; up to three trailing bytes are
/// ignored. The result therefore has `src.len() / 4` elements.
#[inline]
pub fn read_f32_le(src: &[u8]) -> Vec<f32> {
    let count = src.len() / 4;
    let usable = count * 4;
    #[cfg(target_endian = "little")]
    {
        // On little-endian targets the wire format matches the in-memory
        // layout, so the bytes can be copied straight into the output buffer.
        let mut result = vec![0.0f32; count];
        {
            // SAFETY: `result` owns `count` initialized `f32`s, i.e. exactly
            // `usable` bytes; `u8` has no alignment requirement, every bit
            // pattern is a valid `f32`, and the byte view is dropped before
            // `result` is used again.
            let raw = unsafe {
                core::slice::from_raw_parts_mut(result.as_mut_ptr().cast::<u8>(), usable)
            };
            raw.copy_from_slice(&src[..usable]);
        }
        result
    }
    #[cfg(not(target_endian = "little"))]
    {
        src[..usable]
            .chunks_exact(4)
            .map(|chunk| {
                let mut bytes = [0u8; 4];
                bytes.copy_from_slice(chunk);
                f32::from_le_bytes(bytes)
            })
            .collect()
    }
}
740
/// Unit tests: dispatching kernels versus their scalar references, edge
/// cases, panics on mismatched dimensions, and byte (de)serialization.
#[cfg(test)]
#[allow(
    clippy::float_cmp,
    clippy::cast_precision_loss,
    clippy::cast_possible_truncation
)]
mod tests {
    use super::*;

    /// Signature shared by the public float kernels and their scalar twins.
    type Kernel = fn(&[f32], &[f32]) -> f32;

    /// Vector lengths used when comparing a public kernel with its scalar
    /// reference.
    const DIMS: &[usize] = &[1, 7, 8, 15, 16, 31, 32, 128, 384, 768];

    /// Builds two deterministic ramp vectors of length `dim`.
    fn make_vecs(dim: usize) -> (Vec<f32>, Vec<f32>) {
        let first = (0..dim).map(|i| (i as f32) * 0.1 - 5.0).collect();
        let second = (0..dim).map(|i| (i as f32) * 0.2 + 1.0).collect();
        (first, second)
    }

    /// Asserts that `actual` is within `tol` of `expected`, relative to
    /// `max(|expected|, 1)`.
    fn assert_close(actual: f32, expected: f32, tol: f32, label: &str, dim: usize) {
        let diff = (actual - expected).abs();
        let allowed = tol * expected.abs().max(1.0);
        assert!(
            diff < allowed,
            "{label} dim={dim}: expected={expected}, actual={actual}, diff={diff}"
        );
    }

    /// Runs `public` and `scalar` on the same inputs for every length in
    /// `DIMS` and checks that they agree.
    fn assert_matches_scalar(label: &str, public: Kernel, scalar: Kernel) {
        for &dim in DIMS {
            let (a, b) = make_vecs(dim);
            let expected = scalar(&a, &b);
            let actual = public(&a, &b);
            assert_close(actual, expected, 1e-5, label, dim);
        }
    }

    /// Ramp of 64 positive values used by the cosine edge-case tests.
    fn positive_ramp() -> Vec<f32> {
        (0..64).map(|i| (i as f32) * 0.3 + 0.1).collect()
    }

    /// The values `0.0, 1.0, ..., 127.0`.
    fn counting_vec() -> Vec<f32> {
        (0..128).map(|i| i as f32).collect()
    }

    #[test]
    fn dot_product_matches_scalar() {
        assert_matches_scalar("dot_product", dot_product, dot_product_scalar);
    }

    #[test]
    fn euclidean_distance_sq_matches_scalar() {
        assert_matches_scalar(
            "euclidean_distance_sq",
            euclidean_distance_sq,
            euclidean_distance_sq_scalar,
        );
    }

    #[test]
    fn cosine_similarity_matches_scalar() {
        assert_matches_scalar(
            "cosine_similarity",
            cosine_similarity,
            cosine_similarity_scalar,
        );
    }

    #[test]
    fn manhattan_distance_matches_scalar() {
        assert_matches_scalar(
            "manhattan_distance",
            manhattan_distance,
            manhattan_distance_scalar,
        );
    }

    #[test]
    fn hamming_distance_matches_scalar() {
        let byte_pattern = |dim: usize, mul: usize, add: usize| -> Vec<u8> {
            (0..dim).map(|i| (i * mul + add) as u8).collect()
        };
        for dim in [1usize, 7, 8, 15, 16, 31, 32, 64, 128, 256] {
            let a = byte_pattern(dim, 37, 13);
            let b = byte_pattern(dim, 53, 7);
            let scalar = hamming_distance_scalar(&a, &b);
            let result = hamming_distance(&a, &b);
            assert_eq!(
                result, scalar,
                "hamming_distance dim={dim}: scalar={scalar}, simd={result}"
            );
        }
    }

    #[test]
    fn dot_product_zero_vectors() {
        let zeros = vec![0.0f32; 128];
        assert_eq!(dot_product(&zeros, &zeros.clone()), 0.0);
    }

    #[test]
    fn cosine_similarity_zero_vector() {
        let zeros = vec![0.0f32; 32];
        let ones = vec![1.0f32; 32];
        assert_eq!(cosine_similarity(&zeros, &ones), 0.0);
    }

    #[test]
    fn cosine_similarity_identical() {
        let a = positive_ramp();
        let result = cosine_similarity(&a, &a);
        assert!(
            (result - 1.0).abs() < 1e-6,
            "identical vectors: sim={result}"
        );
    }

    #[test]
    fn cosine_similarity_opposite() {
        let a = positive_ramp();
        let negated: Vec<f32> = a.iter().map(|x| -x).collect();
        let result = cosine_similarity(&a, &negated);
        assert!(
            (result - (-1.0)).abs() < 1e-6,
            "opposite vectors: sim={result}"
        );
    }

    #[test]
    fn hamming_distance_known_pattern() {
        // Every bit differs: 32 bytes * 8 bits.
        let all_ones = vec![0xFF_u8; 32];
        let all_zeros = vec![0x00_u8; 32];
        assert_eq!(hamming_distance(&all_ones, &all_zeros), 32 * 8);
    }

    #[test]
    fn hamming_distance_identical() {
        let bytes = vec![0xAB_u8; 64];
        assert_eq!(hamming_distance(&bytes, &bytes), 0);
    }

    #[test]
    fn euclidean_distance_sq_identical() {
        let a = counting_vec();
        assert_eq!(euclidean_distance_sq(&a, &a), 0.0);
    }

    #[test]
    fn manhattan_distance_identical() {
        let a = counting_vec();
        assert_eq!(manhattan_distance(&a, &a), 0.0);
    }

    #[test]
    #[should_panic(expected = "dimension mismatch")]
    fn dot_product_dimension_mismatch_panics() {
        let ten = vec![1.0f32; 10];
        let eleven = vec![1.0f32; 11];
        dot_product(&ten, &eleven);
    }

    #[test]
    #[should_panic(expected = "dimension mismatch")]
    fn euclidean_dimension_mismatch_panics() {
        let ten = vec![1.0f32; 10];
        let eleven = vec![1.0f32; 11];
        euclidean_distance_sq(&ten, &eleven);
    }

    #[test]
    fn distance_metric_nan_returns_max() {
        let with_nan = [1.0f32, f32::NAN, 3.0];
        let clean = [4.0f32, 5.0, 6.0];
        assert_eq!(
            DistanceMetric::EuclideanSq.compute(&with_nan, &clean),
            f32::MAX
        );
    }

    #[test]
    fn distance_metric_mismatch_returns_max() {
        let two = [1.0f32, 2.0];
        let three = [1.0f32, 2.0, 3.0];
        assert_eq!(DistanceMetric::Cosine.compute(&two, &three), f32::MAX);
    }

    #[test]
    fn write_read_f32_le_roundtrip() {
        let values: Vec<f32> = (0..100).map(|i| (i as f32) * 0.123 - 6.0).collect();
        let mut encoded = vec![0u8; values.len() * 4];
        write_f32_le(&mut encoded, &values);
        assert_eq!(read_f32_le(&encoded), values);
    }
}