Skip to main content

tuitbot_core/context/
semantic_index.rs

//! In-memory semantic search index using brute-force cosine similarity.
//!
//! Stores embedding vectors in memory for nearest-neighbor search. Each query is
//! a linear scan computing cosine distance against every stored vector, which is
//! acceptable for indexes of up to roughly 50K vectors. The implementation can
//! later be swapped for an HNSW index behind the same API.
//!
//! SQLite is the source of truth; this index is rebuilt from the DB at startup.
8
9use std::collections::HashMap;
10
11/// Error type for semantic index operations.
12#[derive(Debug, thiserror::Error)]
13pub enum SemanticSearchError {
14    /// Vector dimension does not match index configuration.
15    #[error("dimension mismatch: expected {expected}, got {actual}")]
16    DimensionMismatch {
17        /// Expected dimension.
18        expected: usize,
19        /// Actual dimension provided.
20        actual: usize,
21    },
22
23    /// Index has reached its configured capacity.
24    #[error("index full: capacity is {0}")]
25    IndexFull(usize),
26
27    /// Internal error.
28    #[error("semantic index error: {0}")]
29    Internal(String),
30}
31
/// Brute-force, in-memory store of embedding vectors keyed by chunk id,
/// searched by cosine distance.
pub struct SemanticIndex {
    /// Required length of every stored vector.
    dimension: usize,
    /// Opaque model identifier, returned by the `model_id()` accessor.
    model_id: String,
    /// Maximum number of distinct chunk ids that `insert` accepts.
    capacity: usize,
    /// Stored embedding vectors, keyed by chunk id.
    vectors: HashMap<i64, Vec<f32>>,
}
39
40impl SemanticIndex {
41    /// Create a new empty index.
42    pub fn new(dimension: usize, model_id: String, capacity: usize) -> Self {
43        Self {
44            vectors: HashMap::with_capacity(capacity.min(1024)),
45            dimension,
46            model_id,
47            capacity,
48        }
49    }
50
51    /// Insert a vector for a chunk. Overwrites if chunk_id already exists.
52    pub fn insert(
53        &mut self,
54        chunk_id: i64,
55        embedding: Vec<f32>,
56    ) -> Result<(), SemanticSearchError> {
57        if embedding.len() != self.dimension {
58            return Err(SemanticSearchError::DimensionMismatch {
59                expected: self.dimension,
60                actual: embedding.len(),
61            });
62        }
63
64        if !self.vectors.contains_key(&chunk_id) && self.vectors.len() >= self.capacity {
65            return Err(SemanticSearchError::IndexFull(self.capacity));
66        }
67
68        self.vectors.insert(chunk_id, embedding);
69        Ok(())
70    }
71
72    /// Remove a vector by chunk_id. Returns false if not found.
73    pub fn remove(&mut self, chunk_id: i64) -> bool {
74        self.vectors.remove(&chunk_id).is_some()
75    }
76
77    /// Search for the top-k nearest vectors by cosine similarity.
78    ///
79    /// Returns `(chunk_id, distance)` pairs sorted ascending by distance
80    /// (smaller = more similar). Distance = 1.0 - cosine_similarity.
81    pub fn search(&self, query: &[f32], k: usize) -> Vec<(i64, f32)> {
82        if self.vectors.is_empty() || k == 0 {
83            return vec![];
84        }
85
86        let mut scored: Vec<(i64, f32)> = self
87            .vectors
88            .iter()
89            .map(|(&chunk_id, vec)| {
90                let dist = cosine_distance(query, vec);
91                (chunk_id, dist)
92            })
93            .collect();
94
95        scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
96        scored.truncate(k);
97        scored
98    }
99
100    /// Number of vectors in the index.
101    pub fn len(&self) -> usize {
102        self.vectors.len()
103    }
104
105    /// Whether the index is empty.
106    pub fn is_empty(&self) -> bool {
107        self.vectors.is_empty()
108    }
109
110    /// Model identifier for this index.
111    pub fn model_id(&self) -> &str {
112        &self.model_id
113    }
114
115    /// Vector dimension for this index.
116    pub fn dimension(&self) -> usize {
117        self.dimension
118    }
119
120    /// Clear and rebuild the index from a batch of (chunk_id, embedding) pairs.
121    pub fn rebuild_from(&mut self, embeddings: Vec<(i64, Vec<f32>)>) {
122        self.vectors.clear();
123        for (chunk_id, vec) in embeddings {
124            if vec.len() == self.dimension {
125                self.vectors.insert(chunk_id, vec);
126            }
127        }
128    }
129}
130
/// Compute the cosine distance between two vectors: `1.0 - cosine_similarity`.
///
/// The similarity is clamped to `[-1.0, 1.0]` before subtracting. This stops
/// floating-point rounding from producing a distance outside `[0.0, 2.0]`,
/// such as a tiny negative distance for identical vectors.
///
/// Returns 1.0 when the product of the two vector norms is below
/// `f32::EPSILON`. This covers zero-magnitude vectors as well as extremely
/// small ones. A NaN component makes the result NaN.
///
/// Only the overlapping prefix of `a` and `b` is compared. Callers are
/// expected to pass slices of equal length.
fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    let mut dot = 0.0_f32;
    let mut norm_a = 0.0_f32;
    let mut norm_b = 0.0_f32;

    for (x, y) in a.iter().zip(b.iter()) {
        dot += x * y;
        norm_a += x * x;
        norm_b += y * y;
    }

    // Multiply the square roots rather than taking sqrt of the product to
    // reduce the risk of overflow for large-magnitude vectors.
    let denom = norm_a.sqrt() * norm_b.sqrt();
    if denom < f32::EPSILON {
        return 1.0;
    }

    // `clamp` propagates NaN, so NaN inputs still yield a NaN distance.
    1.0 - (dot / denom).clamp(-1.0, 1.0)
}
152
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `dim`-length vector with every component equal to `val`.
    fn make_vec(val: f32, dim: usize) -> Vec<f32> {
        vec![val; dim]
    }

    #[test]
    fn insert_and_search_finds_nearest() {
        let mut idx = SemanticIndex::new(3, "test".to_string(), 100);
        idx.insert(1, vec![1.0, 0.0, 0.0]).unwrap();
        idx.insert(2, vec![0.0, 1.0, 0.0]).unwrap();
        idx.insert(3, vec![0.9, 0.1, 0.0]).unwrap();

        let results = idx.search(&[1.0, 0.0, 0.0], 2);
        assert_eq!(results.len(), 2);
        // chunk 1 should be the closest (exact match)
        assert_eq!(results[0].0, 1);
        assert!(results[0].1 < 0.01);
    }

    #[test]
    fn search_returns_correct_top_k() {
        let mut idx = SemanticIndex::new(2, "test".to_string(), 100);
        for i in 0..10 {
            idx.insert(i, vec![i as f32, 1.0]).unwrap();
        }

        let results = idx.search(&[9.0, 1.0], 3);
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].0, 9); // closest
    }

    #[test]
    fn remove_makes_vector_unfindable() {
        let mut idx = SemanticIndex::new(2, "test".to_string(), 100);
        idx.insert(1, vec![1.0, 0.0]).unwrap();
        idx.insert(2, vec![0.0, 1.0]).unwrap();

        assert!(idx.remove(1));
        assert_eq!(idx.len(), 1);

        let results = idx.search(&[1.0, 0.0], 10);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, 2);
    }

    #[test]
    fn remove_nonexistent_returns_false() {
        let mut idx = SemanticIndex::new(2, "test".to_string(), 100);
        assert!(!idx.remove(999));
    }

    #[test]
    fn dimension_mismatch_on_insert() {
        let mut idx = SemanticIndex::new(3, "test".to_string(), 100);
        let err = idx.insert(1, vec![1.0, 2.0]).unwrap_err();
        // `matches!` only evaluates to a bool; it must be asserted to check anything.
        assert!(matches!(
            err,
            SemanticSearchError::DimensionMismatch {
                expected: 3,
                actual: 2,
            }
        ));
    }

    #[test]
    fn empty_search_returns_empty() {
        let idx = SemanticIndex::new(3, "test".to_string(), 100);
        let results = idx.search(&[1.0, 0.0, 0.0], 5);
        assert!(results.is_empty());
    }

    #[test]
    fn search_with_k_zero_returns_empty() {
        let mut idx = SemanticIndex::new(2, "test".to_string(), 100);
        idx.insert(1, vec![1.0, 0.0]).unwrap();
        let results = idx.search(&[1.0, 0.0], 0);
        assert!(results.is_empty());
    }

    #[test]
    fn rebuild_replaces_all_contents() {
        let mut idx = SemanticIndex::new(2, "test".to_string(), 100);
        idx.insert(1, vec![1.0, 0.0]).unwrap();
        idx.insert(2, vec![0.0, 1.0]).unwrap();
        assert_eq!(idx.len(), 2);

        idx.rebuild_from(vec![(10, vec![0.5, 0.5]), (11, vec![0.3, 0.7])]);
        assert_eq!(idx.len(), 2);
        assert!(!idx.vectors.contains_key(&1));
        assert!(idx.vectors.contains_key(&10));
    }

    #[test]
    fn capacity_limit_respected() {
        let mut idx = SemanticIndex::new(2, "test".to_string(), 2);
        idx.insert(1, vec![1.0, 0.0]).unwrap();
        idx.insert(2, vec![0.0, 1.0]).unwrap();

        let err = idx.insert(3, vec![0.5, 0.5]).unwrap_err();
        // `matches!` only evaluates to a bool; it must be asserted to check anything.
        assert!(matches!(err, SemanticSearchError::IndexFull(2)));
    }

    #[test]
    fn overwrite_existing_does_not_count_as_new() {
        let mut idx = SemanticIndex::new(2, "test".to_string(), 2);
        idx.insert(1, vec![1.0, 0.0]).unwrap();
        idx.insert(2, vec![0.0, 1.0]).unwrap();
        // Overwrite chunk 1 — should not trigger IndexFull
        idx.insert(1, vec![0.5, 0.5]).unwrap();
        assert_eq!(idx.len(), 2);
    }

    #[test]
    fn accessors() {
        let idx = SemanticIndex::new(768, "nomic-embed-text".to_string(), 50_000);
        assert_eq!(idx.dimension(), 768);
        assert_eq!(idx.model_id(), "nomic-embed-text");
        assert!(idx.is_empty());
        assert_eq!(idx.len(), 0);
    }

    #[test]
    fn cosine_distance_identical_vectors() {
        let dist = cosine_distance(&[1.0, 2.0, 3.0], &[1.0, 2.0, 3.0]);
        assert!(dist.abs() < 1e-5);
    }

    #[test]
    fn cosine_distance_orthogonal_vectors() {
        let dist = cosine_distance(&[1.0, 0.0], &[0.0, 1.0]);
        assert!((dist - 1.0).abs() < 1e-5);
    }

    #[test]
    fn cosine_distance_opposite_vectors() {
        let dist = cosine_distance(&[1.0, 0.0], &[-1.0, 0.0]);
        assert!((dist - 2.0).abs() < 1e-5);
    }

    #[test]
    fn cosine_distance_zero_vector() {
        let dist = cosine_distance(&[0.0, 0.0], &[1.0, 1.0]);
        assert!((dist - 1.0).abs() < 1e-5);
    }

    #[test]
    fn rebuild_skips_wrong_dimension() {
        let mut idx = SemanticIndex::new(3, "test".to_string(), 100);
        idx.rebuild_from(vec![
            (1, vec![1.0, 2.0, 3.0]), // correct dim
            (2, vec![1.0, 2.0]),      // wrong dim — skipped
            (3, make_vec(0.5, 3)),    // correct dim
        ]);
        assert_eq!(idx.len(), 2);
        assert!(idx.vectors.contains_key(&1));
        assert!(!idx.vectors.contains_key(&2));
        assert!(idx.vectors.contains_key(&3));
    }
}