1use ndarray::{Array1, ArrayView1};
2
/// Distance measure used to compare `f32` vectors.
///
/// Every variant is evaluated through [`DistanceMetric::compute`], which returns a
/// value where smaller means "more similar". With the `persistence` feature
/// enabled the enum additionally derives serde's `Serialize`/`Deserialize`.
///
/// `Hash` is derived alongside `Eq` so a metric can be used as a key in hash-based
/// collections (e.g. caching per-metric indexes).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(
    feature = "persistence",
    derive(serde::Serialize, serde::Deserialize)
)]
pub enum DistanceMetric {
    /// Cosine distance, `1 - cos(angle)`; see [`cosine_distance`].
    Cosine,
    /// Euclidean (L2) distance; see [`euclidean_distance`].
    Euclidean,
    /// Negated inner product (`-(a·b)`), so a larger dot product ranks as closer.
    DotProduct,
}
17
18impl DistanceMetric {
19 pub fn compute(&self, a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> f32 {
21 match self {
22 DistanceMetric::Cosine => cosine_distance(a, b),
23 DistanceMetric::Euclidean => euclidean_distance(a, b),
24 DistanceMetric::DotProduct => -dot_product(a, b),
25 }
26 }
27}
28
29pub fn cosine_distance(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> f32 {
31 let dot = a.dot(b);
32 let norm_a = a.dot(a).sqrt();
33 let norm_b = b.dot(b).sqrt();
34 let denom = norm_a * norm_b;
35 if denom < f32::EPSILON {
36 return 1.0;
37 }
38 1.0 - (dot / denom)
39}
40
41#[inline]
46pub fn cosine_distance_normalized(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> f32 {
47 1.0 - a.dot(b)
48}
49
50pub fn euclidean_distance(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> f32 {
55 let norm_a_sq = a.dot(a);
56 let norm_b_sq = b.dot(b);
57 let dot_ab = a.dot(b);
58 (norm_a_sq + norm_b_sq - 2.0 * dot_ab).max(0.0).sqrt()
60}
61
62pub fn dot_product(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> f32 {
64 a.dot(b)
65}
66
67pub fn normalize(v: &mut Array1<f32>) {
69 let norm = v.dot(v).sqrt();
70 if norm > f32::EPSILON {
71 *v /= norm;
72 }
73}
74
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Absolute tolerance shared by the approximate float comparisons below.
    const TOL: f32 = 1e-6;

    #[track_caller]
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < TOL,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_cosine_identical() {
        let a = array![1.0, 0.0, 0.0];
        let b = array![1.0, 0.0, 0.0];
        assert_close(cosine_distance(&a.view(), &b.view()), 0.0);
    }

    #[test]
    fn test_cosine_orthogonal() {
        let a = array![1.0, 0.0];
        let b = array![0.0, 1.0];
        assert_close(cosine_distance(&a.view(), &b.view()), 1.0);
    }

    #[test]
    fn test_cosine_opposite() {
        let a = array![1.0, 0.0];
        let b = array![-1.0, 0.0];
        assert_close(cosine_distance(&a.view(), &b.view()), 2.0);
    }

    #[test]
    fn test_cosine_zero_vector_is_treated_as_orthogonal() {
        let zero = array![0.0, 0.0];
        let x = array![1.0, 0.0];
        assert_close(cosine_distance(&zero.view(), &x.view()), 1.0);
        assert_close(cosine_distance(&zero.view(), &zero.view()), 1.0);
    }

    #[test]
    fn test_cosine_normalized_matches_general_on_unit_vectors() {
        let a = array![0.6, 0.8];
        let b = array![0.8, 0.6];
        assert_close(
            cosine_distance_normalized(&a.view(), &b.view()),
            cosine_distance(&a.view(), &b.view()),
        );
    }

    #[test]
    fn test_euclidean() {
        let a = array![0.0, 0.0];
        let b = array![3.0, 4.0];
        assert_close(euclidean_distance(&a.view(), &b.view()), 5.0);
    }

    #[test]
    fn test_euclidean_identical_and_symmetric() {
        let a = array![1.0, 2.0, 3.0];
        let b = array![4.0, 5.0, 6.0];
        assert_close(euclidean_distance(&a.view(), &a.view()), 0.0);
        assert_close(
            euclidean_distance(&a.view(), &b.view()),
            euclidean_distance(&b.view(), &a.view()),
        );
    }

    #[test]
    fn test_compute_dispatches_to_free_functions() {
        let a = array![1.0, 2.0, 3.0];
        let b = array![4.0, 5.0, 6.0];
        let (av, bv) = (a.view(), b.view());
        assert_eq!(
            DistanceMetric::Cosine.compute(&av, &bv),
            cosine_distance(&av, &bv)
        );
        assert_eq!(
            DistanceMetric::Euclidean.compute(&av, &bv),
            euclidean_distance(&av, &bv)
        );
        // DotProduct is the negated inner product so that smaller means closer.
        assert_eq!(
            DistanceMetric::DotProduct.compute(&av, &bv),
            -dot_product(&av, &bv)
        );
        assert_close(DistanceMetric::DotProduct.compute(&av, &bv), -32.0);
    }

    #[test]
    fn test_normalize() {
        let mut v = array![3.0, 4.0];
        normalize(&mut v);
        let norm = v.dot(&v).sqrt();
        assert_close(norm, 1.0);
    }

    #[test]
    fn test_normalize_zero_vector() {
        let mut v = array![0.0, 0.0, 0.0];
        normalize(&mut v);
        assert_eq!(v, array![0.0, 0.0, 0.0]);
    }
}