1use ndarray::{Array1, ArrayView1};
2
/// Distance function used to compare two `f32` vectors.
///
/// See [`DistanceMetric::compute`] for how each variant is evaluated.
/// `Hash` is derived alongside `Eq` so a metric can be used as a key in
/// hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(
    feature = "persistence",
    derive(serde::Serialize, serde::Deserialize)
)]
pub enum DistanceMetric {
    /// `1 - cos(θ)` between the two vectors (see [`cosine_distance`]).
    Cosine,
    /// Euclidean (L2) distance (see [`euclidean_distance`]).
    Euclidean,
    /// Negated dot product (see [`dot_product`]).
    DotProduct,
}
17
18impl DistanceMetric {
19 pub fn compute(&self, a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> f32 {
21 match self {
22 DistanceMetric::Cosine => cosine_distance(a, b),
23 DistanceMetric::Euclidean => euclidean_distance(a, b),
24 DistanceMetric::DotProduct => -dot_product(a, b),
25 }
26 }
27}
28
29pub fn cosine_distance(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> f32 {
31 let dot = a.dot(b);
32 let norm_a = a.dot(a).sqrt();
33 let norm_b = b.dot(b).sqrt();
34 let denom = norm_a * norm_b;
35 if denom < f32::EPSILON {
36 return 1.0;
37 }
38 1.0 - (dot / denom)
39}
40
41pub fn euclidean_distance(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> f32 {
43 a.iter()
44 .zip(b.iter())
45 .map(|(x, y)| {
46 let d = x - y;
47 d * d
48 })
49 .sum::<f32>()
50 .sqrt()
51}
52
53pub fn dot_product(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> f32 {
55 a.dot(b)
56}
57
58pub fn normalize(v: &mut Array1<f32>) {
60 let norm = v.dot(v).sqrt();
61 if norm > f32::EPSILON {
62 *v /= norm;
63 }
64}
65
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Absolute tolerance for floating-point comparisons in this module.
    const TOL: f32 = 1e-6;

    /// Asserts that `actual` is within `TOL` of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < TOL,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_cosine_identical() {
        let a = array![1.0, 0.0, 0.0];
        let b = array![1.0, 0.0, 0.0];
        assert_close(cosine_distance(&a.view(), &b.view()), 0.0);
    }

    #[test]
    fn test_cosine_orthogonal() {
        let a = array![1.0, 0.0];
        let b = array![0.0, 1.0];
        assert_close(cosine_distance(&a.view(), &b.view()), 1.0);
    }

    #[test]
    fn test_cosine_opposite() {
        let a = array![1.0, 0.0];
        let b = array![-1.0, 0.0];
        assert_close(cosine_distance(&a.view(), &b.view()), 2.0);
    }

    #[test]
    fn test_cosine_zero_vector() {
        let a = array![0.0, 0.0];
        let b = array![1.0, 0.0];
        assert_close(cosine_distance(&a.view(), &b.view()), 1.0);
    }

    #[test]
    fn test_euclidean() {
        let a = array![0.0, 0.0];
        let b = array![3.0, 4.0];
        assert_close(euclidean_distance(&a.view(), &b.view()), 5.0);
    }

    #[test]
    fn test_dot_product() {
        let a = array![1.0, 2.0, 3.0];
        let b = array![4.0, 5.0, 6.0];
        assert_close(dot_product(&a.view(), &b.view()), 32.0);
    }

    #[test]
    fn test_compute_dispatches_to_free_functions() {
        let a = array![1.0, 2.0, 3.0];
        let b = array![4.0, 5.0, 6.0];
        let (a, b) = (a.view(), b.view());
        assert_close(
            DistanceMetric::Cosine.compute(&a, &b),
            cosine_distance(&a, &b),
        );
        assert_close(
            DistanceMetric::Euclidean.compute(&a, &b),
            euclidean_distance(&a, &b),
        );
        // The dot-product metric is the negated dot product.
        assert_close(DistanceMetric::DotProduct.compute(&a, &b), -32.0);
    }

    #[test]
    fn test_normalize() {
        let mut v = array![3.0, 4.0];
        normalize(&mut v);
        assert_close(v.dot(&v).sqrt(), 1.0);
    }

    #[test]
    fn test_normalize_zero_vector() {
        let mut v = array![0.0, 0.0, 0.0];
        normalize(&mut v);
        assert_eq!(v, array![0.0, 0.0, 0.0]);
    }
}