Skip to main content

sapphire_retrieve/
db.rs

1//! Unified retrieve database: FTS5 + vector search.
2//!
3//! [`RetrieveDb`] is the main entry point.  It manages one of the available
4//! storage backends and exposes a unified API for file tracking, document
5//! management, full-text search, and vector search.
6
7use std::{
8    collections::HashMap,
9    path::{Path, PathBuf},
10    sync::{Arc, Mutex},
11};
12
13use crate::{
14    embed::Embedder,
15    error::Result,
16    retrieve_store::{
17        ChunkHit, Document, FileSearchResult, FtsQuery, HybridQuery, RetrieveStore, VectorQuery,
18    },
19    vector_store::VecInfo,
20};
21
22#[cfg(feature = "sqlite-store")]
23use crate::sqlite_store::SqliteStore;
24
25#[cfg(feature = "lancedb-store")]
26use crate::lancedb_store::LanceDbBackend;
27
28#[cfg(feature = "sqlite-store")]
29pub use crate::sqlite_store::SCHEMA_VERSION;
30
31// ── in-memory backend ─────────────────────────────────────────────────────────
32
/// Non-persistent [`RetrieveStore`] backend.
///
/// [`RetrieveDb::open`] selects this store whenever the `sqlite-store` feature
/// is disabled (it may later be replaced by `RetrieveDb::init_lancedb`), and
/// [`open_in_memory`] constructs one directly.
///
/// All data lives in `HashMap`s behind a single `Mutex` and is lost when the
/// store is dropped (at the latest, when the process exits).
struct InMemoryStore {
    // Every trait method locks this for the duration of its work.
    state: Mutex<InMemoryState>,
}
39
/// Mutable contents of an [`InMemoryStore`], guarded by its mutex.
#[derive(Default)]
struct InMemoryState {
    /// Tracked file paths mapped to the mtime last passed to `upsert_file`.
    files: HashMap<String, i64>,
    /// Indexed documents keyed by `Document::id`.  Independent of `files`:
    /// removing a file does not remove its documents.
    documents: HashMap<i64, Document>,
}
45
46impl InMemoryStore {
47    fn new() -> Self {
48        Self {
49            state: Mutex::new(InMemoryState::default()),
50        }
51    }
52}
53
54impl RetrieveStore for InMemoryStore {
55    fn file_mtimes(&self) -> Result<HashMap<String, i64>> {
56        Ok(self.state.lock().unwrap().files.clone())
57    }
58
59    fn upsert_file(&self, path: &str, mtime: i64) -> Result<()> {
60        self.state
61            .lock()
62            .unwrap()
63            .files
64            .insert(path.to_owned(), mtime);
65        Ok(())
66    }
67
68    fn remove_file(&self, path: &str) -> Result<()> {
69        self.state.lock().unwrap().files.remove(path);
70        Ok(())
71    }
72
73    fn file_count(&self) -> Result<u64> {
74        Ok(self.state.lock().unwrap().files.len() as u64)
75    }
76
77    fn upsert_document(&self, doc: &Document) -> Result<()> {
78        self.state
79            .lock()
80            .unwrap()
81            .documents
82            .insert(doc.id, doc.clone());
83        Ok(())
84    }
85
86    fn remove_document(&self, id: i64) -> Result<()> {
87        self.state.lock().unwrap().documents.remove(&id);
88        Ok(())
89    }
90
91    fn rebuild_fts(&self) -> Result<()> {
92        Ok(())
93    }
94
95    fn search_fts(&self, q: &FtsQuery<'_>) -> Result<Vec<FileSearchResult>> {
96        let state = self.state.lock().unwrap();
97        let needle = q.query.to_lowercase();
98        let prefix = q.path_prefix.map(|p| p.to_string_lossy().to_string());
99        let mut results: Vec<FileSearchResult> = state
100            .documents
101            .values()
102            .filter(|doc| {
103                if let Some(ref pfx) = prefix
104                    && !doc.path.starts_with(pfx.as_str())
105                {
106                    return false;
107                }
108                doc.body.to_lowercase().contains(&needle)
109            })
110            .take(q.limit)
111            .map(|doc| FileSearchResult {
112                id: doc.id,
113                path: doc.path.clone(),
114                score: 0.0,
115                chunks: vec![ChunkHit {
116                    line_start: 0,
117                    line_end: 0,
118                    text: String::new(),
119                    score: 0.0,
120                }],
121            })
122            .collect();
123        results.sort_by(|a, b| a.path.cmp(&b.path));
124        Ok(results)
125    }
126
127    fn document_ids(&self) -> Result<Vec<i64>> {
128        Ok(self
129            .state
130            .lock()
131            .unwrap()
132            .documents
133            .keys()
134            .copied()
135            .collect())
136    }
137
138    fn document_count(&self) -> Result<u64> {
139        Ok(self.state.lock().unwrap().documents.len() as u64)
140    }
141
142    fn embed_pending(
143        &self,
144        _embedder: &dyn Embedder,
145        _on_progress: &dyn Fn(usize, usize),
146    ) -> Result<usize> {
147        Ok(0)
148    }
149
150    fn vec_info(&self) -> Result<VecInfo> {
151        Ok(VecInfo {
152            embedding_dim: 0,
153            vector_count: 0,
154            pending_count: 0,
155        })
156    }
157
158    fn search_similar(&self, _q: &VectorQuery<'_>) -> Result<Vec<FileSearchResult>> {
159        Ok(vec![])
160    }
161}
162
163// ── backend state ─────────────────────────────────────────────────────────────
164
165enum BackendState {
166    #[allow(dead_code)]
167    InMemory(Arc<InMemoryStore>),
168    #[cfg(feature = "sqlite-store")]
169    Sqlite(Arc<SqliteStore>),
170    #[cfg(feature = "lancedb-store")]
171    LanceDb(Arc<LanceDbBackend>),
172}
173
174impl BackendState {
175    fn as_store(&self) -> Arc<dyn RetrieveStore> {
176        match self {
177            BackendState::InMemory(s) => Arc::clone(s) as Arc<dyn RetrieveStore>,
178            #[cfg(feature = "sqlite-store")]
179            BackendState::Sqlite(s) => Arc::clone(s) as Arc<dyn RetrieveStore>,
180            #[cfg(feature = "lancedb-store")]
181            BackendState::LanceDb(l) => Arc::clone(l) as Arc<dyn RetrieveStore>,
182        }
183    }
184
185    fn needs_init(&self) -> bool {
186        match self {
187            BackendState::InMemory(_) => true,
188            #[cfg(feature = "sqlite-store")]
189            BackendState::Sqlite(s) => s.dim().is_none(),
190            #[cfg(feature = "lancedb-store")]
191            BackendState::LanceDb(_) => false,
192        }
193    }
194}
195
196// ── RetrieveDb ────────────────────────────────────────────────────────────────
197
/// Unified handle over the active retrieve backend.
///
/// [`RetrieveDb::open`] starts with an FTS-only SQLite store when the
/// `sqlite-store` feature is enabled, and with an in-memory store otherwise.
/// `init_sqlite_vec` / `init_lancedb` can later swap the backend in place
/// through `&self`, which is why it sits behind a `Mutex`.
pub struct RetrieveDb {
    /// Path given to `open`; reused when `init_sqlite_vec` creates the
    /// vector-enabled SQLite store.
    db_path: PathBuf,
    /// The currently active backend.
    backend: Mutex<BackendState>,
}
202
203impl RetrieveDb {
204    pub fn open(db_path: &Path) -> Result<Self> {
205        #[cfg(feature = "sqlite-store")]
206        {
207            let store = SqliteStore::new_fts_only(db_path.to_owned());
208            Ok(Self {
209                db_path: db_path.to_owned(),
210                backend: Mutex::new(BackendState::Sqlite(Arc::new(store))),
211            })
212        }
213
214        #[cfg(not(feature = "sqlite-store"))]
215        Ok(Self {
216            db_path: db_path.to_owned(),
217            backend: Mutex::new(BackendState::InMemory(Arc::new(InMemoryStore::new()))),
218        })
219    }
220
221    pub fn rebuild(db_path: &Path) -> Result<Self> {
222        #[cfg(feature = "sqlite-store")]
223        crate::sqlite_store::wipe_db_files(db_path);
224        Self::open(db_path)
225    }
226
227    #[cfg(feature = "sqlite-store")]
228    pub fn init_sqlite_vec(&self, embedding_dim: u32) -> Result<()> {
229        let mut guard = self.backend.lock().unwrap();
230        if guard.needs_init() {
231            let store = SqliteStore::new_with_vec(self.db_path.clone(), embedding_dim)?;
232            *guard = BackendState::Sqlite(Arc::new(store));
233        }
234        Ok(())
235    }
236
237    #[cfg(feature = "lancedb-store")]
238    pub fn init_lancedb(&self, lancedb_dir: &Path, embedding_dim: u32) -> Result<()> {
239        let mut guard = self.backend.lock().unwrap();
240        if guard.needs_init() {
241            let backend = LanceDbBackend::new(lancedb_dir, embedding_dim)?;
242            *guard = BackendState::LanceDb(Arc::new(backend));
243        }
244        Ok(())
245    }
246
247    fn store(&self) -> Arc<dyn RetrieveStore> {
248        self.backend.lock().unwrap().as_store()
249    }
250
251    // ── document management ───────────────────────────────────────────────────
252
253    pub fn upsert_document(&self, doc: &Document) -> Result<()> {
254        self.store().upsert_document(doc)
255    }
256
257    pub fn remove_document(&self, id: i64) -> Result<()> {
258        self.store().remove_document(id)
259    }
260
261    pub fn rebuild_fts(&self) -> Result<()> {
262        self.store().rebuild_fts()
263    }
264
265    // ── search ────────────────────────────────────────────────────────────────
266
267    pub fn search_fts(&self, q: &FtsQuery<'_>) -> Result<Vec<FileSearchResult>> {
268        self.store().search_fts(q)
269    }
270
271    pub fn search_similar(&self, q: &VectorQuery<'_>) -> Result<Vec<FileSearchResult>> {
272        self.store().search_similar(q)
273    }
274
275    pub fn search_hybrid(&self, q: &HybridQuery<'_>) -> Result<Vec<FileSearchResult>> {
276        self.store().search_hybrid(q)
277    }
278
279    // ── embedding ─────────────────────────────────────────────────────────────
280
281    pub fn embed_pending(
282        &self,
283        embedder: &dyn Embedder,
284        on_progress: impl Fn(usize, usize),
285    ) -> Result<usize> {
286        self.store().embed_pending(embedder, &on_progress)
287    }
288
289    pub fn vec_info(&self) -> Result<VecInfo> {
290        self.store().vec_info()
291    }
292
293    pub fn document_ids(&self) -> Result<Vec<i64>> {
294        self.store().document_ids()
295    }
296
297    pub fn document_count(&self) -> Result<u64> {
298        self.store().document_count()
299    }
300
301    // ── file tracking ─────────────────────────────────────────────────────────
302
303    pub fn file_mtimes(&self) -> Result<HashMap<String, i64>> {
304        self.store().file_mtimes()
305    }
306
307    pub fn upsert_file(&self, path: &str, mtime: i64) -> Result<()> {
308        self.store().upsert_file(path, mtime)
309    }
310
311    pub fn remove_file(&self, path: &str) -> Result<()> {
312        self.store().remove_file(path)
313    }
314
315    pub fn file_count(&self) -> Result<u64> {
316        self.store().file_count()
317    }
318}
319
320// ── free functions ────────────────────────────────────────────────────────────
321
322/// Merge FTS and semantic file-level results via Reciprocal Rank Fusion.
323///
324/// `score(d) = w_fts / (k + rank_fts) + w_sem / (k + rank_sem)`.  Chunks from
325/// both inputs are merged (deduplicated by `(line_start, line_end)`, keeping
326/// the best per-chunk score).  Output is sorted by descending RRF score.
327pub fn merge_rrf_files(
328    fts: &[FileSearchResult],
329    sem: &[FileSearchResult],
330    k: f64,
331    w_fts: f64,
332    w_sem: f64,
333    limit: usize,
334) -> Vec<FileSearchResult> {
335    // Index FTS results by path (stable id alternative would work too).
336    let mut acc: HashMap<String, (FileSearchResult, f64)> = HashMap::new();
337
338    for (rank, file) in fts.iter().enumerate() {
339        let rrf = w_fts / (k + (rank + 1) as f64);
340        acc.insert(file.path.clone(), (file.clone(), rrf));
341    }
342
343    for (rank, file) in sem.iter().enumerate() {
344        let rrf = w_sem / (k + (rank + 1) as f64);
345        acc.entry(file.path.clone())
346            .and_modify(|(existing, s)| {
347                *s += rrf;
348                merge_chunk_hits(&mut existing.chunks, &file.chunks);
349            })
350            .or_insert_with(|| (file.clone(), rrf));
351    }
352
353    let mut merged: Vec<_> = acc.into_values().collect();
354    merged.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
355    merged.truncate(limit);
356
357    merged
358        .into_iter()
359        .map(|(mut file, rrf_score)| {
360            file.score = rrf_score;
361            file
362        })
363        .collect()
364}
365
366/// Merge `incoming` into `existing`, deduplicating by `(line_start, line_end)`.
367///
368/// When a chunk exists in both lists, the one from `existing` is kept (so FTS
369/// scores win over vector scores on the same chunk, which matches the order
370/// `merge_rrf_files` calls this).
371fn merge_chunk_hits(existing: &mut Vec<ChunkHit>, incoming: &[ChunkHit]) {
372    use std::collections::HashSet;
373    let seen: HashSet<(usize, usize)> = existing
374        .iter()
375        .map(|c| (c.line_start, c.line_end))
376        .collect();
377    for c in incoming {
378        if !seen.contains(&(c.line_start, c.line_end)) {
379            existing.push(c.clone());
380        }
381    }
382}
383
384/// Default hybrid search implementation used by [`RetrieveStore::search_hybrid`].
385///
386/// Calls `search_fts` and, when an embedder is provided, `search_similar`;
387/// then merges results via [`merge_rrf_files`].  When `q.embedder` is `None`,
388/// falls back to FTS-only output.
389pub fn default_hybrid<S: RetrieveStore + ?Sized>(
390    store: &S,
391    q: &HybridQuery<'_>,
392) -> Result<Vec<FileSearchResult>> {
393    let over_fetch = q.limit * 3;
394    let fts = store.search_fts(&FtsQuery {
395        query: q.query,
396        limit: over_fetch,
397        path_prefix: q.path_prefix,
398    })?;
399
400    let Some(embedder) = q.embedder else {
401        return Ok(fts.into_iter().take(q.limit).collect());
402    };
403
404    let sem = store.search_similar(&VectorQuery {
405        query: q.query,
406        embedder,
407        limit: over_fetch,
408        path_prefix: q.path_prefix,
409    })?;
410
411    Ok(merge_rrf_files(
412        &fts,
413        &sem,
414        q.rrf_k,
415        q.weight_fts,
416        q.weight_sem,
417        q.limit,
418    ))
419}
420
421// ── backend factory functions ─────────────────────────────────────────────────
422
423/// Open or create an in-memory backend.
424pub fn open_in_memory() -> Arc<dyn RetrieveStore + Send + Sync> {
425    Arc::new(InMemoryStore::new())
426}
427
/// Creates an FTS-only SQLite store (no embedding dimension) for `db_path`.
#[cfg(feature = "sqlite-store")]
pub fn open_sqlite_fts(db_path: &Path) -> Arc<dyn RetrieveStore + Send + Sync> {
    let store = SqliteStore::new_fts_only(db_path.to_path_buf());
    Arc::new(store)
}
432
/// Creates a vector-enabled SQLite store for `db_path` with `dim`-dimensional
/// embeddings.
///
/// # Errors
///
/// Propagates any error from `SqliteStore::new_with_vec`.
#[cfg(feature = "sqlite-store")]
pub fn open_sqlite_vec(db_path: &Path, dim: u32) -> Result<Arc<dyn RetrieveStore + Send + Sync>> {
    let store = SqliteStore::new_with_vec(db_path.to_path_buf(), dim)?;
    Ok(Arc::new(store))
}
440
/// Creates a LanceDB-backed store in `data_dir` with `dim`-dimensional
/// embeddings.
///
/// # Errors
///
/// Propagates any error from `LanceDbBackend::new`.
#[cfg(feature = "lancedb-store")]
pub fn open_lancedb(data_dir: &Path, dim: u32) -> Result<Arc<dyn RetrieveStore + Send + Sync>> {
    let backend = LanceDbBackend::new(data_dir, dim)?;
    Ok(Arc::new(backend))
}