//! Distance metrics and vector helpers used for nearest-neighbor comparisons
//! (`superbit/distance.rs`).

use ndarray::{Array1, ArrayView1};

/// Metric used to compare vectors during nearest-neighbor search.
///
/// Every variant is oriented so that a *smaller* value means *more similar*.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[cfg_attr(feature = "persistence", derive(serde::Deserialize, serde::Serialize))]
pub enum DistanceMetric {
    /// Cosine distance `1 - cos(a, b)`, in `[0, 2]`; `0` means the same direction.
    Cosine,
    /// Euclidean (L2) distance, in `[0, inf)`.
    Euclidean,
    /// Negated dot product, in `(-inf, inf)`; negated so that smaller means more similar.
    DotProduct,
}

18impl DistanceMetric {
19    /// Compute the distance between two vectors using this metric.
20    pub fn compute(&self, a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> f32 {
21        match self {
22            DistanceMetric::Cosine => cosine_distance(a, b),
23            DistanceMetric::Euclidean => euclidean_distance(a, b),
24            DistanceMetric::DotProduct => -dot_product(a, b),
25        }
26    }
27}
28
29/// Cosine distance: 1 - cos(a, b).
30pub fn cosine_distance(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> f32 {
31    let dot = a.dot(b);
32    let norm_a = a.dot(a).sqrt();
33    let norm_b = b.dot(b).sqrt();
34    let denom = norm_a * norm_b;
35    if denom < f32::EPSILON {
36        return 1.0;
37    }
38    1.0 - (dot / denom)
39}
40
41/// Fast cosine distance for pre-normalized vectors: just 1 - dot(a, b).
42///
43/// Both `a` and `b` must already have unit L2 norm. Skips the two extra
44/// dot products that `cosine_distance` needs for norm computation.
45#[inline]
46pub fn cosine_distance_normalized(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> f32 {
47    1.0 - a.dot(b)
48}
49
50/// Euclidean (L2) distance between two vectors.
51///
52/// Uses the identity ||a - b||^2 = ||a||^2 + ||b||^2 - 2*dot(a, b)
53/// to leverage ndarray's optimized dot product instead of a scalar loop.
54pub fn euclidean_distance(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> f32 {
55    let norm_a_sq = a.dot(a);
56    let norm_b_sq = b.dot(b);
57    let dot_ab = a.dot(b);
58    // Clamp to avoid sqrt of small negative due to floating-point error.
59    (norm_a_sq + norm_b_sq - 2.0 * dot_ab).max(0.0).sqrt()
60}
61
62/// Dot product of two vectors.
63pub fn dot_product(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> f32 {
64    a.dot(b)
65}
66
67/// Normalize a vector to unit length (L2 norm). Leaves zero vectors unchanged.
68pub fn normalize(v: &mut Array1<f32>) {
69    let norm = v.dot(v).sqrt();
70    if norm > f32::EPSILON {
71        *v /= norm;
72    }
73}
74
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn test_cosine_identical() {
        let a = array![1.0, 0.0, 0.0];
        let b = array![1.0, 0.0, 0.0];
        let d = cosine_distance(&a.view(), &b.view());
        assert!((d - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_orthogonal() {
        let a = array![1.0, 0.0];
        let b = array![0.0, 1.0];
        let d = cosine_distance(&a.view(), &b.view());
        assert!((d - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_opposite() {
        let a = array![1.0, 0.0];
        let b = array![-1.0, 0.0];
        let d = cosine_distance(&a.view(), &b.view());
        assert!((d - 2.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_zero_vector() {
        // A degenerate input is reported with the orthogonal distance.
        let a = array![0.0, 0.0];
        let b = array![1.0, 2.0];
        assert_eq!(cosine_distance(&a.view(), &b.view()), 1.0);
    }

    #[test]
    fn test_cosine_normalized_matches_cosine() {
        let mut a = array![3.0, 4.0];
        let mut b = array![1.0, 2.0];
        let expected = cosine_distance(&a.view(), &b.view());
        normalize(&mut a);
        normalize(&mut b);
        let d = cosine_distance_normalized(&a.view(), &b.view());
        assert!((d - expected).abs() < 1e-5);
    }

    #[test]
    fn test_euclidean() {
        let a = array![0.0, 0.0];
        let b = array![3.0, 4.0];
        let d = euclidean_distance(&a.view(), &b.view());
        assert!((d - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_euclidean_identical() {
        let a = array![1.0, -2.0, 3.5];
        let d = euclidean_distance(&a.view(), &a.view());
        assert!(d.abs() < 1e-6);
    }

    #[test]
    fn test_dot_product() {
        let a = array![1.0, 2.0, 3.0];
        let b = array![4.0, 5.0, 6.0];
        assert!((dot_product(&a.view(), &b.view()) - 32.0).abs() < 1e-6);
    }

    #[test]
    fn test_metric_compute_dispatch() {
        let a = array![1.0, 2.0];
        let b = array![3.0, 4.0];
        let (av, bv) = (a.view(), b.view());
        assert_eq!(
            DistanceMetric::Cosine.compute(&av, &bv),
            cosine_distance(&av, &bv)
        );
        assert_eq!(
            DistanceMetric::Euclidean.compute(&av, &bv),
            euclidean_distance(&av, &bv)
        );
        assert_eq!(DistanceMetric::DotProduct.compute(&av, &bv), -11.0);
    }

    #[test]
    fn test_normalize() {
        let mut v = array![3.0, 4.0];
        normalize(&mut v);
        let norm = v.dot(&v).sqrt();
        assert!((norm - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_normalize_zero_vector() {
        let mut v = array![0.0, 0.0, 0.0];
        normalize(&mut v);
        assert_eq!(v, array![0.0, 0.0, 0.0]);
    }
}