1use std::collections::HashMap;
10
/// Errors reported by the semantic index.
///
/// The `Display` text of each variant comes from its `#[error(...)]` attribute.
#[derive(Debug, thiserror::Error)]
pub enum SemanticSearchError {
    /// An embedding's length differs from the dimension the index was built with.
    #[error("dimension mismatch: expected {expected}, got {actual}")]
    DimensionMismatch {
        /// Dimension configured on the index.
        expected: usize,
        /// Length of the embedding that was rejected.
        actual: usize,
    },

    /// A new entry was rejected because the index already holds `capacity`
    /// entries; the payload is that capacity.
    #[error("index full: capacity is {0}")]
    IndexFull(usize),

    /// Free-form failure carrying a message.
    ///
    /// NOTE(review): nothing in this file constructs this variant; presumably
    /// it is produced by callers elsewhere — confirm.
    #[error("semantic index error: {0}")]
    Internal(String),
}
31
/// In-memory store of fixed-dimension `f32` embeddings keyed by chunk id.
///
/// `Debug` and `Clone` are derived so the index can be logged, embedded in
/// other `#[derive(Debug)]` types, and snapshotted. Note that both walk every
/// stored vector, so they are proportional to the size of the index.
#[derive(Debug, Clone)]
pub struct SemanticIndex {
    /// Stored embeddings, keyed by chunk id.
    vectors: HashMap<i64, Vec<f32>>,
    /// Required length of every stored embedding.
    dimension: usize,
    /// Identifier string exposed through `model_id()`; presumably names the
    /// embedding model that produced the vectors.
    model_id: String,
    /// Maximum number of distinct chunk ids `insert` will accept.
    capacity: usize,
}
39
40impl SemanticIndex {
41 pub fn new(dimension: usize, model_id: String, capacity: usize) -> Self {
43 Self {
44 vectors: HashMap::with_capacity(capacity.min(1024)),
45 dimension,
46 model_id,
47 capacity,
48 }
49 }
50
51 pub fn insert(
53 &mut self,
54 chunk_id: i64,
55 embedding: Vec<f32>,
56 ) -> Result<(), SemanticSearchError> {
57 if embedding.len() != self.dimension {
58 return Err(SemanticSearchError::DimensionMismatch {
59 expected: self.dimension,
60 actual: embedding.len(),
61 });
62 }
63
64 if !self.vectors.contains_key(&chunk_id) && self.vectors.len() >= self.capacity {
65 return Err(SemanticSearchError::IndexFull(self.capacity));
66 }
67
68 self.vectors.insert(chunk_id, embedding);
69 Ok(())
70 }
71
72 pub fn remove(&mut self, chunk_id: i64) -> bool {
74 self.vectors.remove(&chunk_id).is_some()
75 }
76
77 pub fn search(&self, query: &[f32], k: usize) -> Vec<(i64, f32)> {
82 if self.vectors.is_empty() || k == 0 {
83 return vec![];
84 }
85
86 let mut scored: Vec<(i64, f32)> = self
87 .vectors
88 .iter()
89 .map(|(&chunk_id, vec)| {
90 let dist = cosine_distance(query, vec);
91 (chunk_id, dist)
92 })
93 .collect();
94
95 scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
96 scored.truncate(k);
97 scored
98 }
99
100 pub fn len(&self) -> usize {
102 self.vectors.len()
103 }
104
105 pub fn is_empty(&self) -> bool {
107 self.vectors.is_empty()
108 }
109
110 pub fn model_id(&self) -> &str {
112 &self.model_id
113 }
114
115 pub fn dimension(&self) -> usize {
117 self.dimension
118 }
119
120 pub fn rebuild_from(&mut self, embeddings: Vec<(i64, Vec<f32>)>) {
122 self.vectors.clear();
123 for (chunk_id, vec) in embeddings {
124 if vec.len() == self.dimension {
125 self.vectors.insert(chunk_id, vec);
126 }
127 }
128 }
129}
130
/// Cosine distance (`1 - cosine similarity`) between `a` and `b`.
///
/// The result is always in `[0.0, 2.0]`: `0.0` for vectors pointing the same
/// way, `1.0` for orthogonal vectors, `2.0` for opposite vectors.
///
/// Returns `1.0` when the similarity is undefined: either vector has zero
/// norm (including empty slices), or an input contains NaN or an infinity.
///
/// If the slices differ in length, only the first `min(a.len(), b.len())`
/// components of each are used; the remaining components are ignored.
fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    // Accumulate in f64: squares and products of f32 values cannot overflow
    // or underflow to zero there, so tiny- and huge-magnitude vectors keep
    // their direction instead of collapsing to 0 or inf.
    let (dot, norm_a, norm_b) = a.iter().zip(b.iter()).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, norm_a, norm_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, norm_a + x * x, norm_b + y * y)
        },
    );

    let denom = (norm_a * norm_b).sqrt();
    let similarity = dot / denom;
    // A zero vector gives 0/0; NaN or infinite inputs also end up non-finite.
    if denom == 0.0 || !similarity.is_finite() {
        return 1.0;
    }

    // Rounding can push the ratio a hair outside [-1, 1]; clamp so the
    // distance never leaves [0, 2]. The value is then safely within f32 range.
    (1.0 - similarity.clamp(-1.0, 1.0)) as f32
}
152
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `dim`-length vector with every component set to `val`.
    fn make_vec(val: f32, dim: usize) -> Vec<f32> {
        vec![val; dim]
    }

    #[test]
    fn insert_and_search_finds_nearest() {
        let mut idx = SemanticIndex::new(3, "test".to_string(), 100);
        idx.insert(1, vec![1.0, 0.0, 0.0]).unwrap();
        idx.insert(2, vec![0.0, 1.0, 0.0]).unwrap();
        idx.insert(3, vec![0.9, 0.1, 0.0]).unwrap();

        let results = idx.search(&[1.0, 0.0, 0.0], 2);
        assert_eq!(results.len(), 2);
        // Id 1 is identical to the query, so it must rank first with ~0 distance.
        assert_eq!(results[0].0, 1);
        assert!(results[0].1 < 0.01);
    }

    #[test]
    fn search_returns_correct_top_k() {
        let mut idx = SemanticIndex::new(2, "test".to_string(), 100);
        for i in 0..10 {
            idx.insert(i, vec![i as f32, 1.0]).unwrap();
        }

        let results = idx.search(&[9.0, 1.0], 3);
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].0, 9);
    }

    #[test]
    fn remove_makes_vector_unfindable() {
        let mut idx = SemanticIndex::new(2, "test".to_string(), 100);
        idx.insert(1, vec![1.0, 0.0]).unwrap();
        idx.insert(2, vec![0.0, 1.0]).unwrap();

        assert!(idx.remove(1));
        assert_eq!(idx.len(), 1);

        let results = idx.search(&[1.0, 0.0], 10);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, 2);
    }

    #[test]
    fn remove_nonexistent_returns_false() {
        let mut idx = SemanticIndex::new(2, "test".to_string(), 100);
        assert!(!idx.remove(999));
    }

    #[test]
    fn dimension_mismatch_on_insert() {
        let mut idx = SemanticIndex::new(3, "test".to_string(), 100);
        let err = idx.insert(1, vec![1.0, 2.0]).unwrap_err();
        // `matches!` only yields a bool; it must be asserted or the test
        // passes regardless of which error came back.
        assert!(matches!(
            err,
            SemanticSearchError::DimensionMismatch {
                expected: 3,
                actual: 2,
            }
        ));
    }

    #[test]
    fn empty_search_returns_empty() {
        let idx = SemanticIndex::new(3, "test".to_string(), 100);
        let results = idx.search(&[1.0, 0.0, 0.0], 5);
        assert!(results.is_empty());
    }

    #[test]
    fn search_with_k_zero_returns_empty() {
        let mut idx = SemanticIndex::new(2, "test".to_string(), 100);
        idx.insert(1, vec![1.0, 0.0]).unwrap();
        let results = idx.search(&[1.0, 0.0], 0);
        assert!(results.is_empty());
    }

    #[test]
    fn rebuild_replaces_all_contents() {
        let mut idx = SemanticIndex::new(2, "test".to_string(), 100);
        idx.insert(1, vec![1.0, 0.0]).unwrap();
        idx.insert(2, vec![0.0, 1.0]).unwrap();
        assert_eq!(idx.len(), 2);

        idx.rebuild_from(vec![(10, vec![0.5, 0.5]), (11, vec![0.3, 0.7])]);
        assert_eq!(idx.len(), 2);
        assert!(!idx.vectors.contains_key(&1));
        assert!(idx.vectors.contains_key(&10));
    }

    #[test]
    fn capacity_limit_respected() {
        let mut idx = SemanticIndex::new(2, "test".to_string(), 2);
        idx.insert(1, vec![1.0, 0.0]).unwrap();
        idx.insert(2, vec![0.0, 1.0]).unwrap();

        let err = idx.insert(3, vec![0.5, 0.5]).unwrap_err();
        // Asserted for the same reason as in `dimension_mismatch_on_insert`.
        assert!(matches!(err, SemanticSearchError::IndexFull(2)));
    }

    #[test]
    fn overwrite_existing_does_not_count_as_new() {
        let mut idx = SemanticIndex::new(2, "test".to_string(), 2);
        idx.insert(1, vec![1.0, 0.0]).unwrap();
        idx.insert(2, vec![0.0, 1.0]).unwrap();
        // The index is full, but id 1 already exists, so this is an overwrite.
        idx.insert(1, vec![0.5, 0.5]).unwrap();
        assert_eq!(idx.len(), 2);
    }

    #[test]
    fn accessors() {
        let idx = SemanticIndex::new(768, "nomic-embed-text".to_string(), 50_000);
        assert_eq!(idx.dimension(), 768);
        assert_eq!(idx.model_id(), "nomic-embed-text");
        assert!(idx.is_empty());
        assert_eq!(idx.len(), 0);
    }

    #[test]
    fn cosine_distance_identical_vectors() {
        let dist = cosine_distance(&[1.0, 2.0, 3.0], &[1.0, 2.0, 3.0]);
        assert!(dist.abs() < 1e-5);
    }

    #[test]
    fn cosine_distance_orthogonal_vectors() {
        let dist = cosine_distance(&[1.0, 0.0], &[0.0, 1.0]);
        assert!((dist - 1.0).abs() < 1e-5);
    }

    #[test]
    fn cosine_distance_opposite_vectors() {
        let dist = cosine_distance(&[1.0, 0.0], &[-1.0, 0.0]);
        assert!((dist - 2.0).abs() < 1e-5);
    }

    #[test]
    fn cosine_distance_zero_vector() {
        let dist = cosine_distance(&[0.0, 0.0], &[1.0, 1.0]);
        assert!((dist - 1.0).abs() < 1e-5);
    }

    #[test]
    fn rebuild_skips_wrong_dimension() {
        let mut idx = SemanticIndex::new(3, "test".to_string(), 100);
        idx.rebuild_from(vec![
            (1, vec![1.0, 2.0, 3.0]),
            // Wrong length for a 3-dimensional index; must be skipped.
            (2, vec![1.0, 2.0]),
            (3, make_vec(0.5, 3)),
        ]);
        assert_eq!(idx.len(), 2);
        assert!(idx.vectors.contains_key(&1));
        assert!(!idx.vectors.contains_key(&2));
        assert!(idx.vectors.contains_key(&3));
    }
}