// sqlite_graphrag/similarity.rs
/// Computes the cosine similarity between two vectors.
///
/// The result is `1.0` for vectors pointing in the same direction, `0.0` for
/// orthogonal vectors and `-1.0` for vectors pointing in opposite directions.
/// Finite results are clamped to `[-1.0, 1.0]` so that rounding error can
/// never push the value outside the mathematically valid range.
///
/// Returns `0.0` when the slices differ in length or when either vector has
/// zero magnitude (this includes two empty slices). A NaN component in either
/// input yields NaN.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    // Single pass accumulating the dot product and both squared norms.
    let (dot, norm_a, norm_b) = a.iter().zip(b).fold(
        (0.0_f32, 0.0_f32, 0.0_f32),
        |(dot, norm_a, norm_b), (&x, &y)| (dot + x * y, norm_a + x * x, norm_b + y * y),
    );

    // Take the square roots *before* multiplying. The product of two squared
    // norms leaves the f32 range far sooner than the norms themselves do: it
    // underflows to zero for small vectors and overflows to infinity for
    // large ones, and either case would wrongly report a similarity of 0.
    let denom = norm_a.sqrt() * norm_b.sqrt();
    if denom == 0.0 {
        0.0
    } else {
        (dot / denom).clamp(-1.0, 1.0)
    }
}
35
/// Converts a similarity score into a distance by subtracting it from one.
///
/// A similarity of `1.0` maps to a distance of `0.0`, `0.0` maps to `1.0`,
/// and `-1.0` maps to `2.0`.
pub fn similarity_to_distance(sim: f32) -> f32 {
    // Similarity value at which the resulting distance is zero.
    const MAX_SIMILARITY: f32 = 1.0;
    MAX_SIMILARITY - sim
}
43
/// Ranks scores from highest to lowest and returns the best `k` of them.
///
/// Each returned pair is `(index, score)`, where `index` is the position of
/// the score in the input sequence. Equal scores keep their input order. If
/// `k` exceeds the number of items, every item is returned.
///
/// NaN scores are ranked after every other score, so they only appear in the
/// result when there are fewer than `k` non-NaN scores.
pub fn top_k_by_score<I>(items: I, k: usize) -> Vec<(usize, f32)>
where
    I: IntoIterator<Item = f32>,
{
    let mut scored: Vec<(usize, f32)> = items.into_iter().enumerate().collect();
    // Treating NaN as "equal to everything" is not a total order, which makes
    // the sort result unspecified. Order NaN explicitly instead.
    scored.sort_by(|a, b| match (a.1.is_nan(), b.1.is_nan()) {
        // Both scores are real numbers, so `partial_cmp` always succeeds;
        // the operands are swapped to sort in descending order.
        (false, false) => b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal),
        // `false < true`, so a NaN entry sorts after a non-NaN one and two
        // NaN entries compare equal.
        (a_is_nan, b_is_nan) => a_is_nan.cmp(&b_is_nan),
    });
    scored.truncate(k);
    scored
}
55
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance used when comparing floating-point results.
    const EPSILON: f32 = 1e-6;

    /// Returns `true` when `actual` differs from `expected` by less than `EPSILON`.
    fn is_close(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < EPSILON
    }

    #[test]
    fn identical_vectors_have_similarity_one() {
        let sample = [0.5_f32; 4];
        assert!(is_close(cosine_similarity(&sample, &sample), 1.0));
    }

    #[test]
    fn orthogonal_vectors_have_similarity_zero() {
        let x_axis = [1.0_f32, 0.0];
        let y_axis = [0.0_f32, 1.0];
        assert!(is_close(cosine_similarity(&x_axis, &y_axis), 0.0));
    }

    #[test]
    fn opposite_vectors_have_similarity_minus_one() {
        let forward = [1.0_f32, 0.0];
        let backward = [-1.0_f32, 0.0];
        assert!(is_close(cosine_similarity(&forward, &backward), -1.0));
    }

    #[test]
    fn zero_vector_returns_zero() {
        let origin = [0.0_f32; 3];
        let other = [1.0_f32, 2.0, 3.0];
        // The zero vector must yield 0.0 regardless of argument position.
        assert_eq!(cosine_similarity(&origin, &other), 0.0);
        assert_eq!(cosine_similarity(&other, &origin), 0.0);
    }

    #[test]
    fn mismatched_lengths_return_zero() {
        let short = [1.0_f32, 2.0];
        let long = [1.0_f32, 2.0, 3.0];
        assert_eq!(cosine_similarity(&short, &long), 0.0);
    }

    #[test]
    fn similarity_to_distance_inverts_correctly() {
        // (similarity, expected distance) pairs.
        let cases = [(1.0_f32, 0.0_f32), (0.0, 1.0), (-1.0, 2.0)];
        for (sim, expected) in cases.iter().copied() {
            assert!(is_close(similarity_to_distance(sim), expected));
        }
    }

    #[test]
    fn top_k_returns_sorted_truncated() {
        let scores = vec![0.1, 0.9, 0.5, 0.3, 0.7];
        let expected: Vec<(usize, f32)> = vec![(1, 0.9), (4, 0.7), (2, 0.5)];
        assert_eq!(top_k_by_score(scores, 3), expected);
    }
}