1use std::time::SystemTime;
2
3use domain::error::Result;
4use domain::model::EmbeddingEntry;
5use domain::ports::VectorStore;
6
7use crate::mapping::map_rusqlite_error;
8use crate::SqliteStore;
9
/// Value written to the `provider` column of every row stored in the `embeddings` table.
// NOTE(review): presumably the name of the embedding model that produced the vectors — confirm.
const PROVIDER: &str = "all-MiniLM-L6-v2";
11
/// Formats the current wall-clock time as a UTC timestamp of the form
/// `YYYY-MM-DDTHH:MM:SSZ`, using only the standard library.
///
/// If the system clock reports a time before the Unix epoch, the epoch itself
/// (`1970-01-01T00:00:00Z`) is used instead of failing.
fn now_rfc3339() -> String {
    const SECS_PER_DAY: u64 = 86400;

    // Gregorian leap-year rule.
    let is_leap = |year: i64| year % 4 == 0 && (year % 100 != 0 || year % 400 == 0);

    let since_epoch = SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs();

    // Time of day, derived from the seconds left over after whole days.
    let secs_of_day = since_epoch % SECS_PER_DAY;
    let hours = secs_of_day / 3600;
    let mins = secs_of_day % 3600 / 60;
    let s = secs_of_day % 60;

    // Peel whole years off the day count, starting at 1970, until what
    // remains is a zero-based day index inside year `y`.
    let mut day_index = (since_epoch / SECS_PER_DAY) as i64;
    let mut y = 1970i64;
    loop {
        let days_in_year = if is_leap(y) { 366 } else { 365 };
        if day_index < days_in_year {
            break;
        }
        day_index -= days_in_year;
        y += 1;
    }

    // Same idea for months: subtract month lengths until the remainder fits.
    let february = if is_leap(y) { 29 } else { 28 };
    let mut month = 1;
    for days_in_month in [31, february, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31] {
        if day_index < days_in_month {
            break;
        }
        day_index -= days_in_month;
        month += 1;
    }

    // `day_index` is zero-based; calendar days are one-based.
    format!(
        "{y:04}-{month:02}-{:02}T{hours:02}:{mins:02}:{s:02}Z",
        day_index + 1
    )
}
65
/// Serialises a slice of `f32` values into a contiguous byte buffer, four
/// little-endian bytes per value, in input order.
fn pack_f32(vec: &[f32]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(vec.len() * 4);
    bytes.extend(vec.iter().flat_map(|value| value.to_le_bytes()));
    bytes
}
73
/// Decodes a byte blob back into `f32` values, reading four little-endian
/// bytes per value.
///
/// Trailing bytes that do not form a complete four-byte group are ignored.
fn unpack_f32(blob: &[u8]) -> Vec<f32> {
    let mut values = Vec::with_capacity(blob.len() / 4);
    for word in blob.chunks_exact(4) {
        // `chunks_exact(4)` only yields slices of length four, so these
        // indices are always in bounds.
        values.push(f32::from_le_bytes([word[0], word[1], word[2], word[3]]));
    }
    values
}
79
/// Computes the cosine similarity of `a` and `b`, accumulating in `f64`.
///
/// Returns `0.0` when:
/// * the slices have different lengths (the similarity is undefined, and
///   indexing past the shorter slice would otherwise panic), or
/// * either vector has zero magnitude, which also covers two empty slices.
///
/// For equal-length, non-zero inputs the result is
/// `dot(a, b) / (|a| * |b|)`.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
    // Vectors of different dimensionality are not comparable; report
    // "no similarity" rather than panicking or scoring a truncated prefix.
    if a.len() != b.len() {
        return 0.0;
    }

    // One pass accumulating the dot product and both squared norms, in the
    // same element order as a plain index loop.
    let (dot, norm_a, norm_b) = a.iter().zip(b.iter()).fold(
        (0.0f64, 0.0f64, 0.0f64),
        |(dot, norm_a, norm_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, norm_a + x * x, norm_b + y * y)
        },
    );

    let denom = norm_a.sqrt() * norm_b.sqrt();
    if denom == 0.0 {
        0.0
    } else {
        dot / denom
    }
}
98
99impl VectorStore for SqliteStore {
100 fn store_embeddings(&self, entries: &[EmbeddingEntry]) -> Result<()> {
101 if entries.is_empty() {
102 return Ok(());
103 }
104 let conn = self.conn()?;
105 let created_at = now_rfc3339();
106 let mut stmt = conn
107 .prepare_cached(
108 "INSERT OR REPLACE INTO embeddings (qualified_name, vector, text_hash, provider, created_at)
109 VALUES (?1, ?2, ?3, ?4, ?5)",
110 )
111 .map_err(map_rusqlite_error)?;
112 for entry in entries {
113 let blob = pack_f32(&entry.vector);
114 stmt.execute(rusqlite::params![
115 &entry.qualified_name,
116 blob,
117 &entry.text_hash,
118 PROVIDER,
119 &created_at,
120 ])
121 .map_err(map_rusqlite_error)?;
122 }
123 Ok(())
124 }
125
126 fn search_nearest(&self, query_vec: &[f32], limit: usize) -> Result<Vec<(String, f64)>> {
127 let conn = self.conn()?;
128 let mut stmt = conn
129 .prepare_cached("SELECT qualified_name, vector FROM embeddings")
130 .map_err(map_rusqlite_error)?;
131 let rows = stmt
132 .query_map([], |row| {
133 Ok((row.get::<_, String>(0)?, row.get::<_, Vec<u8>>(1)?))
134 })
135 .map_err(map_rusqlite_error)?;
136
137 let mut scored: Vec<(String, f64)> = Vec::new();
138 for row in rows {
139 let (qn, blob) = row.map_err(map_rusqlite_error)?;
140 let vec = unpack_f32(&blob);
141 let sim = cosine_similarity(query_vec, &vec);
142 scored.push((qn, sim));
143 }
144
145 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
146 scored.truncate(limit);
147 Ok(scored)
148 }
149
150 fn has_embeddings(&self) -> bool {
151 self.conn()
152 .ok()
153 .and_then(|conn| {
154 conn.query_row("SELECT EXISTS(SELECT 1 FROM embeddings)", [], |r| {
155 r.get::<_, i32>(0)
156 })
157 .ok()
158 })
159 .map(|v| v != 0)
160 .unwrap_or(false)
161 }
162
163 fn count(&self) -> Result<usize> {
164 let conn = self.conn()?;
165 let n: i64 = conn
166 .query_row("SELECT COUNT(*) FROM embeddings", [], |r| r.get(0))
167 .map_err(map_rusqlite_error)?;
168 Ok(n as usize)
169 }
170
171 fn remove_embeddings(&self, qualified_names: &[&str]) -> Result<()> {
172 if qualified_names.is_empty() {
173 return Ok(());
174 }
175 let conn = self.conn()?;
176 let placeholders: String = (1..=qualified_names.len())
179 .map(|i| format!("?{i}"))
180 .collect::<Vec<_>>()
181 .join(", ");
182 let sql = format!("DELETE FROM embeddings WHERE qualified_name IN ({placeholders})");
183 let mut stmt = conn.prepare(&sql).map_err(map_rusqlite_error)?;
184 stmt.execute(rusqlite::params_from_iter(qualified_names.iter()))
185 .map_err(map_rusqlite_error)?;
186 Ok(())
187 }
188
189 fn get_stored_hashes(&self) -> Result<Vec<(String, String)>> {
190 let conn = self.conn()?;
191 let mut stmt = conn
192 .prepare_cached("SELECT qualified_name, text_hash FROM embeddings")
193 .map_err(map_rusqlite_error)?;
194 let rows = stmt
195 .query_map([], |row| {
196 Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?))
197 })
198 .map_err(map_rusqlite_error)?;
199 let mut result = Vec::new();
200 for row in rows {
201 result.push(row.map_err(map_rusqlite_error)?);
202 }
203 Ok(result)
204 }
205}
206
207#[cfg(test)]
208mod tests {
209 use super::*;
210 use domain::model::{FileNode, Language, Location, SymbolKind, SymbolNode, Visibility};
211 use domain::ports::{GraphStore, VectorStore};
212
213 fn setup() -> SqliteStore {
214 SqliteStore::open_in_memory().unwrap()
215 }
216
217 fn make_entry(qn: &str, vec: Vec<f32>) -> EmbeddingEntry {
218 EmbeddingEntry {
219 qualified_name: qn.to_string(),
220 vector: vec,
221 text_hash: format!("hash_{qn}"),
222 }
223 }
224
225 fn insert_symbol(store: &SqliteStore, file_path: &str, qn: &str) {
227 let file = FileNode {
228 path: file_path.into(),
229 language: Language::Rust,
230 hash: "h".into(),
231 };
232 store.upsert_file(&file).unwrap();
233 let sym = SymbolNode {
234 name: qn.split("::").last().unwrap_or(qn).to_string(),
235 qualified_name: qn.to_string(),
236 kind: SymbolKind::Function,
237 location: Location {
238 file: file_path.into(),
239 line_start: 1,
240 line_end: 10,
241 col_start: 0,
242 col_end: 1,
243 },
244 visibility: Visibility::Public,
245 is_exported: true,
246 is_async: false,
247 is_test: false,
248 decorators: vec![],
249 signature: None,
250 };
251 store.upsert_symbol(&sym).unwrap();
252 }
253
254 #[test]
255 fn has_embeddings_false_when_empty() {
256 let store = setup();
257 assert!(!store.has_embeddings());
258 }
259
260 #[test]
261 fn has_embeddings_true_after_store() {
262 let store = setup();
263 insert_symbol(&store, "src/a.rs", "src/a.rs::foo");
264 store
265 .store_embeddings(&[make_entry("src/a.rs::foo", vec![1.0, 0.0])])
266 .unwrap();
267 assert!(store.has_embeddings());
268 }
269
270 #[test]
271 fn count_returns_correct_number() {
272 let store = setup();
273 assert_eq!(store.count().unwrap(), 0);
274 insert_symbol(&store, "src/a.rs", "src/a.rs::foo");
275 insert_symbol(&store, "src/b.rs", "src/b.rs::bar");
276 store
277 .store_embeddings(&[
278 make_entry("src/a.rs::foo", vec![1.0, 0.0]),
279 make_entry("src/b.rs::bar", vec![0.0, 1.0]),
280 ])
281 .unwrap();
282 assert_eq!(store.count().unwrap(), 2);
283 }
284
285 #[test]
286 fn store_and_retrieve_embeddings() {
287 let store = setup();
288 insert_symbol(&store, "src/a.rs", "src/a.rs::foo");
289 insert_symbol(&store, "src/b.rs", "src/b.rs::bar");
290
291 store
292 .store_embeddings(&[
293 make_entry("src/a.rs::foo", vec![1.0, 0.0, 0.0]),
294 make_entry("src/b.rs::bar", vec![0.0, 1.0, 0.0]),
295 ])
296 .unwrap();
297
298 let results = store.search_nearest(&[1.0, 0.0, 0.0], 10).unwrap();
300 assert_eq!(results.len(), 2);
301 assert_eq!(results[0].0, "src/a.rs::foo");
302 assert!(results[0].1 > results[1].1);
303 }
304
305 #[test]
306 fn cosine_similarity_ranking() {
307 let store = setup();
308 insert_symbol(&store, "src/a.rs", "src/a.rs::close");
309 insert_symbol(&store, "src/b.rs", "src/b.rs::far");
310
311 store
313 .store_embeddings(&[
314 make_entry("src/a.rs::close", vec![0.9, 0.1]),
315 make_entry("src/b.rs::far", vec![0.0, 1.0]),
316 ])
317 .unwrap();
318
319 let results = store.search_nearest(&[1.0, 0.0], 2).unwrap();
320 assert_eq!(results[0].0, "src/a.rs::close");
321 assert!(results[0].1 > results[1].1);
322 }
323
324 #[test]
325 fn remove_embeddings_deletes_entries() {
326 let store = setup();
327 insert_symbol(&store, "src/a.rs", "src/a.rs::foo");
328 insert_symbol(&store, "src/b.rs", "src/b.rs::bar");
329
330 store
331 .store_embeddings(&[
332 make_entry("src/a.rs::foo", vec![1.0, 0.0]),
333 make_entry("src/b.rs::bar", vec![0.0, 1.0]),
334 ])
335 .unwrap();
336 assert_eq!(store.count().unwrap(), 2);
337
338 store.remove_embeddings(&["src/a.rs::foo"]).unwrap();
339 assert_eq!(store.count().unwrap(), 1);
340
341 let results = store.search_nearest(&[1.0, 0.0], 10).unwrap();
342 assert_eq!(results.len(), 1);
343 assert_eq!(results[0].0, "src/b.rs::bar");
344 }
345
346 #[test]
347 fn store_embeddings_upserts() {
348 let store = setup();
349 insert_symbol(&store, "src/a.rs", "src/a.rs::foo");
350
351 store
353 .store_embeddings(&[make_entry("src/a.rs::foo", vec![1.0, 0.0])])
354 .unwrap();
355 assert_eq!(store.count().unwrap(), 1);
356
357 store
359 .store_embeddings(&[make_entry("src/a.rs::foo", vec![0.0, 1.0])])
360 .unwrap();
361 assert_eq!(store.count().unwrap(), 1);
362
363 let results = store.search_nearest(&[0.0, 1.0], 1).unwrap();
365 assert_eq!(results.len(), 1);
366 assert!((results[0].1 - 1.0).abs() < 1e-6);
367 }
368}