tuitbot_core/context/
semantic_search.rs1use super::semantic_index::SemanticIndex;
9
/// A single result of [`semantic_search`]: one indexed chunk and how close it
/// is to the query embedding.
///
/// `PartialEq` is derived so hits can be compared directly (e.g. with
/// `assert_eq!`); note that float fields make `NaN` scores compare unequal.
#[derive(Debug, Clone, PartialEq)]
pub struct SemanticHit {
    /// Identifier of the matched chunk, as passed to `SemanticIndex::insert`.
    pub chunk_id: i64,
    /// Raw distance reported by `SemanticIndex::search` for this chunk.
    pub distance: f32,
    /// Convenience score derived as `1.0 - distance`; larger means closer.
    pub similarity: f32,
}
20
21pub fn semantic_search(
26 index: &SemanticIndex,
27 query_embedding: &[f32],
28 limit: usize,
29) -> Vec<SemanticHit> {
30 let raw = index.search(query_embedding, limit);
31
32 raw.into_iter()
33 .map(|(chunk_id, distance)| SemanticHit {
34 chunk_id,
35 distance,
36 similarity: 1.0 - distance,
37 })
38 .collect()
39}
40
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a fresh, empty index of dimension `dim` for a single test.
    fn make_index(dim: usize) -> SemanticIndex {
        SemanticIndex::new(dim, "test-model".to_string(), 1000)
    }

    #[test]
    fn empty_index_returns_empty_results() {
        let idx = make_index(3);
        let hits = semantic_search(&idx, &[1.0, 0.0, 0.0], 5);
        assert!(hits.is_empty());
    }

    #[test]
    fn search_returns_sorted_by_similarity() {
        let mut idx = make_index(3);
        idx.insert(1, vec![1.0, 0.0, 0.0]).unwrap();
        idx.insert(2, vec![0.9, 0.1, 0.0]).unwrap();
        idx.insert(3, vec![0.0, 1.0, 0.0]).unwrap();

        let hits = semantic_search(&idx, &[1.0, 0.0, 0.0], 3);
        assert_eq!(hits.len(), 3);

        // Exact match, near match, orthogonal: the ranking is unambiguous, so
        // pin the full id order rather than only the pairwise scores.
        let ids: Vec<i64> = hits.iter().map(|h| h.chunk_id).collect();
        assert_eq!(ids, vec![1, 2, 3]);
        assert!(hits[0].similarity > 0.99);
        assert!(hits
            .windows(2)
            .all(|pair| pair[0].similarity >= pair[1].similarity));
    }

    #[test]
    fn limit_enforcement() {
        let mut idx = make_index(2);
        for i in 0..10 {
            idx.insert(i, vec![i as f32, 1.0]).unwrap();
        }

        let hits = semantic_search(&idx, &[9.0, 1.0], 3);
        assert_eq!(hits.len(), 3);
    }

    #[test]
    fn limit_larger_than_index_returns_every_entry() {
        let mut idx = make_index(2);
        idx.insert(1, vec![1.0, 0.0]).unwrap();
        idx.insert(2, vec![0.0, 1.0]).unwrap();

        let hits = semantic_search(&idx, &[1.0, 0.0], 10);
        assert_eq!(hits.len(), 2);
    }

    #[test]
    fn zero_limit_returns_empty() {
        let mut idx = make_index(2);
        idx.insert(1, vec![1.0, 0.0]).unwrap();

        let hits = semantic_search(&idx, &[1.0, 0.0], 0);
        assert!(hits.is_empty());
    }

    #[test]
    fn similarity_is_one_minus_distance() {
        let mut idx = make_index(2);
        idx.insert(1, vec![1.0, 0.0]).unwrap();

        let hits = semantic_search(&idx, &[1.0, 0.0], 1);
        assert_eq!(hits.len(), 1);
        let h = &hits[0];
        assert!((h.similarity - (1.0 - h.distance)).abs() < 1e-6);
    }

    #[test]
    fn orthogonal_vectors_have_zero_similarity() {
        let mut idx = make_index(2);
        idx.insert(1, vec![0.0, 1.0]).unwrap();

        let hits = semantic_search(&idx, &[1.0, 0.0], 1);
        assert_eq!(hits.len(), 1);
        assert!(hits[0].similarity.abs() < 1e-5);
    }
}