Skip to main content

superbit/
distance.rs

1use ndarray::{Array1, ArrayView1};
2
/// Selects how the distance between two vectors is measured during
/// nearest-neighbor comparisons.
///
/// For every variant a smaller value means "more similar".
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(
    feature = "persistence",
    derive(serde::Serialize, serde::Deserialize)
)]
pub enum DistanceMetric {
    /// Cosine distance, `1 - cos(a, b)`. Range `[0, 2]`; `0` means the
    /// vectors point in the same direction.
    Cosine,
    /// Euclidean (L2) distance. Range `[0, inf)`.
    Euclidean,
    /// Negated dot product, so that a larger dot product yields a smaller
    /// distance. Range `(-inf, inf)`.
    DotProduct,
}
17
18impl DistanceMetric {
19    /// Compute the distance between two vectors using this metric.
20    pub fn compute(&self, a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> f32 {
21        match self {
22            DistanceMetric::Cosine => cosine_distance(a, b),
23            DistanceMetric::Euclidean => euclidean_distance(a, b),
24            DistanceMetric::DotProduct => -dot_product(a, b),
25        }
26    }
27}
28
29/// Cosine distance: 1 - cos(a, b).
30pub fn cosine_distance(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> f32 {
31    let dot = a.dot(b);
32    let norm_a = a.dot(a).sqrt();
33    let norm_b = b.dot(b).sqrt();
34    let denom = norm_a * norm_b;
35    if denom < f32::EPSILON {
36        return 1.0;
37    }
38    1.0 - (dot / denom)
39}
40
41/// Euclidean (L2) distance between two vectors.
42pub fn euclidean_distance(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> f32 {
43    a.iter()
44        .zip(b.iter())
45        .map(|(x, y)| {
46            let d = x - y;
47            d * d
48        })
49        .sum::<f32>()
50        .sqrt()
51}
52
53/// Dot product of two vectors.
54pub fn dot_product(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> f32 {
55    a.dot(b)
56}
57
58/// Normalize a vector to unit length (L2 norm). Leaves zero vectors unchanged.
59pub fn normalize(v: &mut Array1<f32>) {
60    let norm = v.dot(v).sqrt();
61    if norm > f32::EPSILON {
62        *v /= norm;
63    }
64}
65
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Absolute tolerance used when comparing floating-point results.
    const TOLERANCE: f32 = 1e-6;

    /// Asserts that `actual` is within [`TOLERANCE`] of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < TOLERANCE,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    /// Identical vectors have cosine distance 0.
    #[test]
    fn test_cosine_identical() {
        let lhs = array![1.0, 0.0, 0.0];
        let rhs = array![1.0, 0.0, 0.0];
        assert_close(cosine_distance(&lhs.view(), &rhs.view()), 0.0);
    }

    /// Orthogonal vectors have cosine distance 1.
    #[test]
    fn test_cosine_orthogonal() {
        let x_axis = array![1.0, 0.0];
        let y_axis = array![0.0, 1.0];
        assert_close(cosine_distance(&x_axis.view(), &y_axis.view()), 1.0);
    }

    /// The 3-4-5 right triangle: distance from the origin to (3, 4) is 5.
    #[test]
    fn test_euclidean() {
        let origin = array![0.0, 0.0];
        let point = array![3.0, 4.0];
        assert_close(euclidean_distance(&origin.view(), &point.view()), 5.0);
    }

    /// A non-zero vector has unit length after normalization.
    #[test]
    fn test_normalize() {
        let mut v = array![3.0, 4.0];
        normalize(&mut v);
        assert_close(v.dot(&v).sqrt(), 1.0);
    }

    /// The zero vector is left unchanged by normalization.
    #[test]
    fn test_normalize_zero_vector() {
        let mut v = array![0.0, 0.0, 0.0];
        normalize(&mut v);
        assert_eq!(v, array![0.0, 0.0, 0.0]);
    }
}
109}