Skip to main content

sql_rs/vector/
similarity.rs

1use serde::{Deserialize, Serialize};
2
3#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
4pub enum DistanceMetric {
5    Cosine,
6    Euclidean,
7    DotProduct,
8}
9
10pub trait Distance {
11    fn distance(&self, other: &[f32], metric: DistanceMetric) -> f32;
12}
13
14impl Distance for Vec<f32> {
15    fn distance(&self, other: &[f32], metric: DistanceMetric) -> f32 {
16        match metric {
17            DistanceMetric::Cosine => cosine_distance(self, other),
18            DistanceMetric::Euclidean => euclidean_distance(self, other),
19            DistanceMetric::DotProduct => dot_product_distance(self, other),
20        }
21    }
22}
23
24impl Distance for [f32] {
25    fn distance(&self, other: &[f32], metric: DistanceMetric) -> f32 {
26        match metric {
27            DistanceMetric::Cosine => cosine_distance(self, other),
28            DistanceMetric::Euclidean => euclidean_distance(self, other),
29            DistanceMetric::DotProduct => dot_product_distance(self, other),
30        }
31    }
32}
33
34fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
35    let dot = dot_product(a, b);
36    let norm_a = magnitude(a);
37    let norm_b = magnitude(b);
38    
39    if norm_a == 0.0 || norm_b == 0.0 {
40        return 1.0;
41    }
42    
43    1.0 - (dot / (norm_a * norm_b))
44}
45
/// Euclidean (L2) distance `sqrt(Σ (a_i - b_i)²)` between `a` and `b`.
///
/// # Panics
///
/// With debug assertions enabled, panics if `a` and `b` have different
/// lengths. In release builds the pairwise `zip` stops at the shorter slice
/// and the extra elements are silently ignored, which gives a meaningless
/// result, so callers must pass vectors of equal dimension.
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "vector dimension mismatch");
    a.iter()
        .zip(b)
        .map(|(x, y)| (x - y).powi(2))
        .sum::<f32>()
        .sqrt()
}
53
54fn dot_product_distance(a: &[f32], b: &[f32]) -> f32 {
55    -dot_product(a, b)
56}
57
/// Inner product `Σ a_i * b_i` of `a` and `b`.
///
/// # Panics
///
/// With debug assertions enabled, panics if `a` and `b` have different
/// lengths. In release builds `zip` stops at the shorter slice and the
/// remaining elements are silently ignored.
fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "vector dimension mismatch");
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}
61
/// Euclidean (L2) norm of `v`: `sqrt(Σ v_i²)`.
///
/// An empty slice yields zero. Squares are accumulated directly in `f32`, so
/// any component with magnitude above roughly `1.8e19` overflows to infinity.
fn magnitude(v: &[f32]) -> f32 {
    let sum_of_squares: f32 = v.iter().map(|&c| c * c).sum();
    sum_of_squares.sqrt()
}
65
66#[allow(dead_code)]
67pub fn normalize(v: &mut [f32]) {
68    let mag = magnitude(v);
69    if mag > 0.0 {
70        for x in v.iter_mut() {
71            *x /= mag;
72        }
73    }
74}
75
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for comparisons whose result goes through `sqrt` or division.
    const EPS: f32 = 1e-6;

    #[test]
    fn test_cosine_distance() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];
        assert!((cosine_distance(&a, &b) - 1.0).abs() < EPS);

        let c = vec![1.0, 0.0, 0.0];
        let d = vec![1.0, 0.0, 0.0];
        assert!(cosine_distance(&c, &d).abs() < EPS);
    }

    #[test]
    fn test_cosine_distance_opposite_vectors() {
        let a: Vec<f32> = vec![1.0, 2.0];
        let b: Vec<f32> = vec![-1.0, -2.0];
        assert!((cosine_distance(&a, &b) - 2.0).abs() < EPS);
    }

    #[test]
    fn test_cosine_distance_zero_vector() {
        // A zero-magnitude input has no direction; the implementation
        // reports 1.0 (the orthogonal distance) for either argument order.
        let zero: Vec<f32> = vec![0.0, 0.0];
        let other: Vec<f32> = vec![1.0, 2.0];
        assert_eq!(cosine_distance(&zero, &other), 1.0);
        assert_eq!(cosine_distance(&other, &zero), 1.0);
    }

    #[test]
    fn test_euclidean_distance() {
        let a = vec![0.0, 0.0];
        let b = vec![3.0, 4.0];
        assert!((euclidean_distance(&a, &b) - 5.0).abs() < EPS);
    }

    #[test]
    fn test_dot_product_distance() {
        let a: Vec<f32> = vec![1.0, 2.0, 3.0];
        let b: Vec<f32> = vec![4.0, 5.0, 6.0];
        assert_eq!(dot_product_distance(&a, &b), -32.0);
    }

    #[test]
    fn test_trait_dispatch_matches_free_functions() {
        let a: Vec<f32> = vec![1.0, 2.0, 3.0];
        let b: Vec<f32> = vec![4.0, 5.0, 6.0];
        for metric in [
            DistanceMetric::Cosine,
            DistanceMetric::Euclidean,
            DistanceMetric::DotProduct,
        ] {
            let expected = match metric {
                DistanceMetric::Cosine => cosine_distance(&a, &b),
                DistanceMetric::Euclidean => euclidean_distance(&a, &b),
                DistanceMetric::DotProduct => dot_product_distance(&a, &b),
            };
            // Both the `Vec<f32>` and the `[f32]` impls must agree.
            assert_eq!(a.distance(&b, metric), expected);
            assert_eq!(a.as_slice().distance(&b, metric), expected);
        }
    }

    #[test]
    fn test_normalize() {
        let mut v = vec![3.0, 4.0];
        normalize(&mut v);
        assert!((magnitude(&v) - 1.0).abs() < EPS);
    }

    #[test]
    fn test_normalize_zero_vector_is_noop() {
        let mut v: Vec<f32> = vec![0.0, 0.0];
        normalize(&mut v);
        assert_eq!(v, vec![0.0, 0.0]);
    }
}