// sql_rs/vector/similarity.rs
use serde::{Deserialize, Serialize};

3#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
4pub enum DistanceMetric {
5 Cosine,
6 Euclidean,
7 DotProduct,
8}
9
10pub trait Distance {
11 fn distance(&self, other: &[f32], metric: DistanceMetric) -> f32;
12}
13
14impl Distance for Vec<f32> {
15 fn distance(&self, other: &[f32], metric: DistanceMetric) -> f32 {
16 match metric {
17 DistanceMetric::Cosine => cosine_distance(self, other),
18 DistanceMetric::Euclidean => euclidean_distance(self, other),
19 DistanceMetric::DotProduct => dot_product_distance(self, other),
20 }
21 }
22}
23
24impl Distance for [f32] {
25 fn distance(&self, other: &[f32], metric: DistanceMetric) -> f32 {
26 match metric {
27 DistanceMetric::Cosine => cosine_distance(self, other),
28 DistanceMetric::Euclidean => euclidean_distance(self, other),
29 DistanceMetric::DotProduct => dot_product_distance(self, other),
30 }
31 }
32}
33
34fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
35 let dot = dot_product(a, b);
36 let norm_a = magnitude(a);
37 let norm_b = magnitude(b);
38
39 if norm_a == 0.0 || norm_b == 0.0 {
40 return 1.0;
41 }
42
43 1.0 - (dot / (norm_a * norm_b))
44}
45
/// Straight-line (L2) distance between `a` and `b`.
///
/// Components are paired with `zip`, so trailing entries of the longer slice
/// are ignored.
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    let squared_total: f32 = a
        .iter()
        .zip(b)
        .map(|(&lhs, &rhs)| {
            let delta = lhs - rhs;
            delta * delta
        })
        .sum();
    squared_total.sqrt()
}

54fn dot_product_distance(a: &[f32], b: &[f32]) -> f32 {
55 -dot_product(a, b)
56}
57
/// Sum of pairwise component products of `a` and `b`.
///
/// Components are paired with `zip`, so only the first
/// `min(a.len(), b.len())` entries contribute.
fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum::<f32>()
}

/// Euclidean (L2) norm of `v`: the square root of its summed squared components.
fn magnitude(v: &[f32]) -> f32 {
    let sum_of_squares: f32 = v.iter().map(|&component| component * component).sum();
    sum_of_squares.sqrt()
}

66#[allow(dead_code)]
67pub fn normalize(v: &mut [f32]) {
68 let mag = magnitude(v);
69 if mag > 0.0 {
70 for x in v.iter_mut() {
71 *x /= mag;
72 }
73 }
74}
75
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cosine_distance() {
        // Orthogonal vectors are at distance 1; identical vectors at distance 0.
        let a = [1.0, 0.0, 0.0];
        let b = [0.0, 1.0, 0.0];
        assert!((cosine_distance(&a, &b) - 1.0).abs() < 1e-6);

        let c = [1.0, 0.0, 0.0];
        let d = [1.0, 0.0, 0.0];
        assert!(cosine_distance(&c, &d).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_distance_opposite_vectors() {
        let a = [1.0, 2.0];
        let b = [-1.0, -2.0];
        assert!((cosine_distance(&a, &b) - 2.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_distance_zero_vector() {
        // Angle is undefined for a zero vector; the implementation reports 1.0.
        assert_eq!(cosine_distance(&[0.0, 0.0], &[1.0, 2.0]), 1.0);
        assert_eq!(cosine_distance(&[1.0, 2.0], &[0.0, 0.0]), 1.0);
    }

    #[test]
    fn test_euclidean_distance() {
        let a = [0.0, 0.0];
        let b = [3.0, 4.0];
        assert!((euclidean_distance(&a, &b) - 5.0).abs() < 1e-6);
        assert_eq!(euclidean_distance(&[1.5, -2.0], &[1.5, -2.0]), 0.0);
    }

    #[test]
    fn test_dot_product_distance() {
        // Negated, so a larger dot product ranks as "closer".
        assert_eq!(
            dot_product_distance(&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]),
            -32.0
        );
    }

    #[test]
    fn test_trait_dispatch_matches_helpers() {
        let a: Vec<f32> = vec![1.0, 2.0, 3.0];
        let b: Vec<f32> = vec![4.0, 5.0, 6.0];
        for metric in [
            DistanceMetric::Cosine,
            DistanceMetric::Euclidean,
            DistanceMetric::DotProduct,
        ] {
            let expected = match metric {
                DistanceMetric::Cosine => cosine_distance(&a, &b),
                DistanceMetric::Euclidean => euclidean_distance(&a, &b),
                DistanceMetric::DotProduct => dot_product_distance(&a, &b),
            };
            // Both the `Vec<f32>` and `[f32]` impls must agree with the helper.
            assert_eq!(a.distance(&b, metric), expected);
            assert_eq!(a.as_slice().distance(&b, metric), expected);
        }
    }

    #[test]
    fn test_normalize() {
        let mut v = [3.0_f32, 4.0];
        normalize(&mut v);
        assert!((magnitude(&v) - 1.0).abs() < 1e-6);
        assert!((v[0] - 0.6).abs() < 1e-6);
        assert!((v[1] - 0.8).abs() < 1e-6);
    }

    #[test]
    fn test_normalize_zero_vector_is_noop() {
        let mut v = [0.0_f32; 3];
        normalize(&mut v);
        assert_eq!(v, [0.0; 3]);
    }
}