Skip to main content

the_code_graph_storage/
embedding_store.rs

1use std::time::SystemTime;
2
3use domain::error::Result;
4use domain::model::EmbeddingEntry;
5use domain::ports::VectorStore;
6
7use crate::mapping::map_rusqlite_error;
8use crate::SqliteStore;
9
/// Name of the embedding model written into the `provider` column of every row
/// inserted by `store_embeddings`.
const PROVIDER: &str = "all-MiniLM-L6-v2";
11
/// Returns the current wall-clock time as a UTC RFC 3339 / ISO 8601 timestamp
/// (`YYYY-MM-DDTHH:MM:SSZ`, whole seconds), without pulling in `chrono`.
///
/// A system clock set before the Unix epoch is clamped to `1970-01-01T00:00:00Z`
/// (via `unwrap_or_default`) rather than reported as an error.
fn now_rfc3339() -> String {
    let secs = SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs();
    format_rfc3339_utc(secs)
}

/// Formats `secs` seconds after the Unix epoch as `YYYY-MM-DDTHH:MM:SSZ` (UTC,
/// proleptic Gregorian calendar).
///
/// Kept separate from the clock read so the calendar arithmetic is deterministic
/// and testable.
fn format_rfc3339_utc(secs: u64) -> String {
    let (mut days, time_of_day) = (secs / 86_400, secs % 86_400);
    let (hours, mins, s) = (time_of_day / 3600, time_of_day % 3600 / 60, time_of_day % 60);

    // Peel whole years off the day count; what remains is the 0-based day of `year`.
    let mut year = 1970u64;
    loop {
        let year_len = if is_leap_year(year) { 366 } else { 365 };
        if days < year_len {
            break;
        }
        days -= year_len;
        year += 1;
    }

    // Then peel whole months; `days` ends up as the 0-based day of the month.
    let feb = if is_leap_year(year) { 29 } else { 28 };
    let month_lengths = [31, feb, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31];
    let mut month = 1;
    for len in month_lengths {
        if days < len {
            break;
        }
        days -= len;
        month += 1;
    }

    format!(
        "{year:04}-{month:02}-{:02}T{hours:02}:{mins:02}:{s:02}Z",
        days + 1
    )
}

/// Gregorian leap-year rule: divisible by 4, except centuries not divisible by 400.
fn is_leap_year(year: u64) -> bool {
    year % 4 == 0 && (year % 100 != 0 || year % 400 == 0)
}
65
/// Serializes `vec` as consecutive little-endian IEEE-754 `f32` values, 4 bytes per
/// component. This is the byte layout written to the `embeddings.vector` column.
fn pack_f32(vec: &[f32]) -> Vec<u8> {
    vec.iter()
        .flat_map(|component| component.to_le_bytes())
        .collect()
}
73
/// Decodes a blob of little-endian `f32` values (the layout produced by `pack_f32`).
///
/// Any trailing bytes that do not form a complete 4-byte value are ignored.
fn unpack_f32(blob: &[u8]) -> Vec<f32> {
    let mut out = Vec::with_capacity(blob.len() / 4);
    for quad in blob.chunks_exact(4) {
        out.push(f32::from_le_bytes([quad[0], quad[1], quad[2], quad[3]]));
    }
    out
}
79
/// Cosine similarity between `a` and `b`, accumulated in `f64`.
///
/// Returns `0.0` when either vector has zero magnitude, and also when the vectors
/// have different lengths. Vectors of different dimensions are not comparable, and
/// treating them as unrelated avoids an out-of-bounds panic (and a silently truncated
/// score) that element-wise indexing would otherwise produce.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
    if a.len() != b.len() {
        return 0.0;
    }
    let (dot, norm_a, norm_b) = a.iter().zip(b).fold(
        (0.0f64, 0.0f64, 0.0f64),
        |(dot, na, nb), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, na + x * x, nb + y * y)
        },
    );
    let denom = norm_a.sqrt() * norm_b.sqrt();
    if denom == 0.0 {
        0.0
    } else {
        dot / denom
    }
}
98
99impl VectorStore for SqliteStore {
100    fn store_embeddings(&self, entries: &[EmbeddingEntry]) -> Result<()> {
101        if entries.is_empty() {
102            return Ok(());
103        }
104        let conn = self.conn()?;
105        let created_at = now_rfc3339();
106        let mut stmt = conn
107            .prepare_cached(
108                "INSERT OR REPLACE INTO embeddings (qualified_name, vector, text_hash, provider, created_at)
109                 VALUES (?1, ?2, ?3, ?4, ?5)",
110            )
111            .map_err(map_rusqlite_error)?;
112        for entry in entries {
113            let blob = pack_f32(&entry.vector);
114            stmt.execute(rusqlite::params![
115                &entry.qualified_name,
116                blob,
117                &entry.text_hash,
118                PROVIDER,
119                &created_at,
120            ])
121            .map_err(map_rusqlite_error)?;
122        }
123        Ok(())
124    }
125
126    fn search_nearest(&self, query_vec: &[f32], limit: usize) -> Result<Vec<(String, f64)>> {
127        let conn = self.conn()?;
128        let mut stmt = conn
129            .prepare_cached("SELECT qualified_name, vector FROM embeddings")
130            .map_err(map_rusqlite_error)?;
131        let rows = stmt
132            .query_map([], |row| {
133                Ok((row.get::<_, String>(0)?, row.get::<_, Vec<u8>>(1)?))
134            })
135            .map_err(map_rusqlite_error)?;
136
137        let mut scored: Vec<(String, f64)> = Vec::new();
138        for row in rows {
139            let (qn, blob) = row.map_err(map_rusqlite_error)?;
140            let vec = unpack_f32(&blob);
141            let sim = cosine_similarity(query_vec, &vec);
142            scored.push((qn, sim));
143        }
144
145        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
146        scored.truncate(limit);
147        Ok(scored)
148    }
149
150    fn has_embeddings(&self) -> bool {
151        self.conn()
152            .ok()
153            .and_then(|conn| {
154                conn.query_row("SELECT EXISTS(SELECT 1 FROM embeddings)", [], |r| {
155                    r.get::<_, i32>(0)
156                })
157                .ok()
158            })
159            .map(|v| v != 0)
160            .unwrap_or(false)
161    }
162
163    fn count(&self) -> Result<usize> {
164        let conn = self.conn()?;
165        let n: i64 = conn
166            .query_row("SELECT COUNT(*) FROM embeddings", [], |r| r.get(0))
167            .map_err(map_rusqlite_error)?;
168        Ok(n as usize)
169    }
170
171    fn remove_embeddings(&self, qualified_names: &[&str]) -> Result<()> {
172        if qualified_names.is_empty() {
173            return Ok(());
174        }
175        let conn = self.conn()?;
176        // SAFETY: placeholders are numeric indices (?1, ?2, ...) derived from the slice
177        // length — no user data is interpolated into SQL. Values are bound via params_from_iter.
178        let placeholders: String = (1..=qualified_names.len())
179            .map(|i| format!("?{i}"))
180            .collect::<Vec<_>>()
181            .join(", ");
182        let sql = format!("DELETE FROM embeddings WHERE qualified_name IN ({placeholders})");
183        let mut stmt = conn.prepare(&sql).map_err(map_rusqlite_error)?;
184        stmt.execute(rusqlite::params_from_iter(qualified_names.iter()))
185            .map_err(map_rusqlite_error)?;
186        Ok(())
187    }
188
189    fn get_stored_hashes(&self) -> Result<Vec<(String, String)>> {
190        let conn = self.conn()?;
191        let mut stmt = conn
192            .prepare_cached("SELECT qualified_name, text_hash FROM embeddings")
193            .map_err(map_rusqlite_error)?;
194        let rows = stmt
195            .query_map([], |row| {
196                Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?))
197            })
198            .map_err(map_rusqlite_error)?;
199        let mut result = Vec::new();
200        for row in rows {
201            result.push(row.map_err(map_rusqlite_error)?);
202        }
203        Ok(result)
204    }
205}
206
#[cfg(test)]
mod tests {
    use super::*;
    use domain::model::{FileNode, Language, Location, SymbolKind, SymbolNode, Visibility};
    use domain::ports::{GraphStore, VectorStore};

    /// A fresh in-memory store for each test.
    fn setup() -> SqliteStore {
        SqliteStore::open_in_memory().expect("in-memory store should open")
    }

    /// Builds an entry whose `text_hash` is derived from its qualified name.
    fn make_entry(qn: &str, vec: Vec<f32>) -> EmbeddingEntry {
        let text_hash = format!("hash_{qn}");
        EmbeddingEntry {
            qualified_name: qn.to_owned(),
            vector: vec,
            text_hash,
        }
    }

    /// Registers `file_path` plus a public function symbol `qn` inside it, so the FK
    /// constraint on `embeddings.qualified_name` is satisfied.
    fn insert_symbol(store: &SqliteStore, file_path: &str, qn: &str) {
        store
            .upsert_file(&FileNode {
                path: file_path.into(),
                language: Language::Rust,
                hash: "h".into(),
            })
            .unwrap();

        let location = Location {
            file: file_path.into(),
            line_start: 1,
            line_end: 10,
            col_start: 0,
            col_end: 1,
        };
        let symbol = SymbolNode {
            name: qn.split("::").last().unwrap_or(qn).to_string(),
            qualified_name: qn.to_string(),
            kind: SymbolKind::Function,
            location,
            visibility: Visibility::Public,
            is_exported: true,
            is_async: false,
            is_test: false,
            decorators: vec![],
            signature: None,
        };
        store.upsert_symbol(&symbol).unwrap();
    }

    /// Inserts the symbol for every `(file, qualified_name, vector)` triple, then
    /// stores all of their embeddings with a single `store_embeddings` call.
    fn seed(store: &SqliteStore, items: &[(&str, &str, Vec<f32>)]) {
        for (file, qn, _) in items {
            insert_symbol(store, file, qn);
        }
        let entries: Vec<EmbeddingEntry> = items
            .iter()
            .map(|(_, qn, vector)| make_entry(qn, vector.clone()))
            .collect();
        store.store_embeddings(&entries).unwrap();
    }

    #[test]
    fn has_embeddings_false_when_empty() {
        assert!(!setup().has_embeddings());
    }

    #[test]
    fn has_embeddings_true_after_store() {
        let store = setup();
        seed(&store, &[("src/a.rs", "src/a.rs::foo", vec![1.0, 0.0])]);
        assert!(store.has_embeddings());
    }

    #[test]
    fn count_returns_correct_number() {
        let store = setup();
        assert_eq!(store.count().unwrap(), 0);
        seed(
            &store,
            &[
                ("src/a.rs", "src/a.rs::foo", vec![1.0, 0.0]),
                ("src/b.rs", "src/b.rs::bar", vec![0.0, 1.0]),
            ],
        );
        assert_eq!(store.count().unwrap(), 2);
    }

    #[test]
    fn store_and_retrieve_embeddings() {
        let store = setup();
        seed(
            &store,
            &[
                ("src/a.rs", "src/a.rs::foo", vec![1.0, 0.0, 0.0]),
                ("src/b.rs", "src/b.rs::bar", vec![0.0, 1.0, 0.0]),
            ],
        );

        // A query pointing straight at "foo" must rank it first.
        let hits = store.search_nearest(&[1.0, 0.0, 0.0], 10).unwrap();
        assert_eq!(hits.len(), 2);
        let (best, runner_up) = (&hits[0], &hits[1]);
        assert_eq!(best.0, "src/a.rs::foo");
        assert!(best.1 > runner_up.1);
    }

    #[test]
    fn cosine_similarity_ranking() {
        let store = setup();
        // "close" sits near the query direction; "far" is orthogonal to it.
        seed(
            &store,
            &[
                ("src/a.rs", "src/a.rs::close", vec![0.9, 0.1]),
                ("src/b.rs", "src/b.rs::far", vec![0.0, 1.0]),
            ],
        );

        let hits = store.search_nearest(&[1.0, 0.0], 2).unwrap();
        let (best, runner_up) = (&hits[0], &hits[1]);
        assert_eq!(best.0, "src/a.rs::close");
        assert!(best.1 > runner_up.1);
    }

    #[test]
    fn remove_embeddings_deletes_entries() {
        let store = setup();
        seed(
            &store,
            &[
                ("src/a.rs", "src/a.rs::foo", vec![1.0, 0.0]),
                ("src/b.rs", "src/b.rs::bar", vec![0.0, 1.0]),
            ],
        );
        assert_eq!(store.count().unwrap(), 2);

        store.remove_embeddings(&["src/a.rs::foo"]).unwrap();
        assert_eq!(store.count().unwrap(), 1);

        // Only "bar" is left, whatever the query.
        let hits = store.search_nearest(&[1.0, 0.0], 10).unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].0, "src/b.rs::bar");
    }

    #[test]
    fn store_embeddings_upserts() {
        let store = setup();
        seed(&store, &[("src/a.rs", "src/a.rs::foo", vec![1.0, 0.0])]);
        assert_eq!(store.count().unwrap(), 1);

        // Re-storing the same qualified_name must replace the row, not add one.
        store
            .store_embeddings(&[make_entry("src/a.rs::foo", vec![0.0, 1.0])])
            .unwrap();
        assert_eq!(store.count().unwrap(), 1);

        // Only the replacement vector [0.0, 1.0] can score a perfect match.
        let hits = store.search_nearest(&[0.0, 1.0], 1).unwrap();
        assert_eq!(hits.len(), 1);
        assert!((hits[0].1 - 1.0).abs() < 1e-6);
    }
}