// synaptic_sqlite/vectorstore.rs

1use std::collections::HashMap;
2use std::sync::{Arc, Mutex};
3
4use async_trait::async_trait;
5use rusqlite::Connection;
6use serde_json::Value;
7use synaptic_core::{Document, Embeddings, SynapticError, VectorStore};
8
/// Settings used to open a [`SqliteVectorStore`].
#[derive(Clone, Debug)]
pub struct SqliteVectorStoreConfig {
    /// Location of the SQLite database file. The special value `":memory:"`
    /// selects an in-memory database instead of a file on disk.
    pub path: String,
}

impl SqliteVectorStoreConfig {
    /// Build a configuration that points at the database file `path`.
    pub fn new(path: impl Into<String>) -> Self {
        let path: String = path.into();
        SqliteVectorStoreConfig { path }
    }

    /// Build a configuration for an in-memory database (`":memory:"`).
    pub fn in_memory() -> Self {
        Self::new(":memory:")
    }
}
29
30/// SQLite-backed implementation of the [`VectorStore`] trait.
31///
32/// Stores document embeddings as BLOBs (little-endian f32 sequences) and
33/// computes cosine similarity in Rust. An FTS5 virtual table provides
34/// full-text search for [`hybrid_search`](SqliteVectorStore::hybrid_search).
35pub struct SqliteVectorStore {
36    conn: Arc<Mutex<Connection>>,
37}
38
39impl SqliteVectorStore {
40    /// Create a new `SqliteVectorStore` from the given configuration.
41    ///
42    /// Opens (or creates) the SQLite database and initializes the vectors
43    /// and FTS5 tables if they do not already exist.
44    pub fn new(config: SqliteVectorStoreConfig) -> Result<Self, SynapticError> {
45        let conn = Connection::open(&config.path)
46            .map_err(|e| SynapticError::VectorStore(format!("SQLite open error: {e}")))?;
47
48        conn.execute_batch(
49            "CREATE TABLE IF NOT EXISTS synaptic_vectors (
50                id        TEXT PRIMARY KEY,
51                content   TEXT NOT NULL,
52                metadata  TEXT NOT NULL DEFAULT '{}',
53                embedding BLOB
54            );
55            CREATE VIRTUAL TABLE IF NOT EXISTS synaptic_vectors_fts USING fts5(
56                content, id UNINDEXED
57            );",
58        )
59        .map_err(|e| SynapticError::VectorStore(format!("SQLite create table error: {e}")))?;
60
61        Ok(Self {
62            conn: Arc::new(Mutex::new(conn)),
63        })
64    }
65
66    /// Hybrid search combining cosine similarity and BM25 full-text scoring.
67    ///
68    /// `alpha` controls the balance:
69    /// - `1.0` = pure vector similarity
70    /// - `0.0` = pure BM25 text relevance
71    /// - `0.5` = balanced (typical default)
72    ///
73    /// The final score is `alpha * cosine + (1 - alpha) * normalized_bm25`.
74    pub async fn hybrid_search(
75        &self,
76        query: &str,
77        k: usize,
78        embeddings: &dyn Embeddings,
79        alpha: f32,
80    ) -> Result<Vec<(Document, f32)>, SynapticError> {
81        let query_vec = embeddings.embed_query(query).await?;
82        let conn = self.conn.clone();
83        let query = query.to_string();
84
85        tokio::task::spawn_blocking(move || {
86            let conn = conn
87                .lock()
88                .map_err(|e| SynapticError::VectorStore(format!("lock error: {e}")))?;
89
90            // Step 1: FTS5 MATCH to get text-relevant docs with BM25 scores.
91            // bm25() returns negative values (closer to 0 = better match).
92            let fts_results: HashMap<String, f64> = {
93                let mut stmt = conn
94                    .prepare(
95                        "SELECT id, bm25(synaptic_vectors_fts) as score
96                         FROM synaptic_vectors_fts WHERE synaptic_vectors_fts MATCH ?1",
97                    )
98                    .map_err(|e| {
99                        SynapticError::VectorStore(format!("SQLite FTS prepare error: {e}"))
100                    })?;
101
102                let rows: Vec<(String, f64)> = stmt
103                    .query_map(rusqlite::params![query], |row| {
104                        Ok((row.get::<_, String>(0)?, row.get::<_, f64>(1)?))
105                    })
106                    .map_err(|e| {
107                        SynapticError::VectorStore(format!("SQLite FTS query error: {e}"))
108                    })?
109                    .filter_map(|r| r.ok())
110                    .collect();
111
112                rows.into_iter().collect()
113            };
114
115            // Step 2: Load all docs with embeddings for cosine scoring.
116            let mut stmt = conn
117                .prepare(
118                    "SELECT id, content, metadata, embedding FROM synaptic_vectors
119                     WHERE embedding IS NOT NULL",
120                )
121                .map_err(|e| SynapticError::VectorStore(format!("SQLite prepare error: {e}")))?;
122
123            let all_docs: Vec<(Document, Vec<f32>)> = stmt
124                .query_map([], |row| {
125                    Ok((
126                        row.get::<_, String>(0)?,
127                        row.get::<_, String>(1)?,
128                        row.get::<_, String>(2)?,
129                        row.get::<_, Vec<u8>>(3)?,
130                    ))
131                })
132                .map_err(|e| SynapticError::VectorStore(format!("SQLite query error: {e}")))?
133                .filter_map(|r| r.ok())
134                .map(|(id, content, meta_str, blob)| {
135                    let metadata: HashMap<String, Value> =
136                        serde_json::from_str(&meta_str).unwrap_or_default();
137                    let embedding = blob_to_embed(&blob);
138                    (
139                        Document {
140                            id,
141                            content,
142                            metadata,
143                        },
144                        embedding,
145                    )
146                })
147                .collect();
148
149            // Normalize BM25 scores to [0, 1] range.
150            // bm25() returns negative values; we negate and normalize.
151            let bm25_max = fts_results
152                .values()
153                .map(|s| -s) // negate: higher = better
154                .fold(f64::NEG_INFINITY, f64::max);
155            let bm25_max = if bm25_max <= 0.0 { 1.0 } else { bm25_max };
156
157            // Step 3: Compute hybrid scores.
158            let mut scored: Vec<(Document, f32)> = all_docs
159                .into_iter()
160                .map(|(doc, emb)| {
161                    let cosine = cosine_similarity(&query_vec, &emb);
162                    let bm25_raw = fts_results.get(&doc.id).copied().unwrap_or(0.0);
163                    let bm25_normalized = (-bm25_raw / bm25_max) as f32;
164                    let score = alpha * cosine + (1.0 - alpha) * bm25_normalized;
165                    (doc, score)
166                })
167                .collect();
168
169            scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
170            scored.truncate(k);
171
172            Ok(scored)
173        })
174        .await
175        .map_err(|e| SynapticError::VectorStore(format!("spawn_blocking error: {e}")))?
176    }
177}
178
179#[async_trait]
180impl VectorStore for SqliteVectorStore {
181    async fn add_documents(
182        &self,
183        docs: Vec<Document>,
184        embeddings: &dyn Embeddings,
185    ) -> Result<Vec<String>, SynapticError> {
186        let texts: Vec<&str> = docs.iter().map(|d| d.content.as_str()).collect();
187        let vectors = embeddings.embed_documents(&texts).await?;
188
189        let conn = self.conn.clone();
190
191        let docs_with_vecs: Vec<(Document, Vec<f32>)> = docs.into_iter().zip(vectors).collect();
192
193        tokio::task::spawn_blocking(move || {
194            let conn = conn
195                .lock()
196                .map_err(|e| SynapticError::VectorStore(format!("lock error: {e}")))?;
197
198            let mut ids = Vec::with_capacity(docs_with_vecs.len());
199
200            for (mut doc, embedding) in docs_with_vecs {
201                // Auto-assign UUID if id is empty.
202                if doc.id.is_empty() {
203                    doc.id = uuid::Uuid::new_v4().to_string();
204                }
205
206                let meta_str = serde_json::to_string(&doc.metadata).map_err(|e| {
207                    SynapticError::VectorStore(format!("JSON serialize error: {e}"))
208                })?;
209                let blob = embed_to_blob(&embedding);
210
211                conn.execute(
212                    "INSERT OR REPLACE INTO synaptic_vectors (id, content, metadata, embedding)
213                     VALUES (?1, ?2, ?3, ?4)",
214                    rusqlite::params![doc.id, doc.content, meta_str, blob],
215                )
216                .map_err(|e| SynapticError::VectorStore(format!("SQLite insert error: {e}")))?;
217
218                // Sync FTS: delete old then insert new.
219                conn.execute(
220                    "DELETE FROM synaptic_vectors_fts WHERE id = ?1",
221                    rusqlite::params![doc.id],
222                )
223                .map_err(|e| SynapticError::VectorStore(format!("SQLite FTS delete error: {e}")))?;
224
225                conn.execute(
226                    "INSERT INTO synaptic_vectors_fts (content, id) VALUES (?1, ?2)",
227                    rusqlite::params![doc.content, doc.id],
228                )
229                .map_err(|e| SynapticError::VectorStore(format!("SQLite FTS insert error: {e}")))?;
230
231                ids.push(doc.id);
232            }
233
234            Ok(ids)
235        })
236        .await
237        .map_err(|e| SynapticError::VectorStore(format!("spawn_blocking error: {e}")))?
238    }
239
240    async fn similarity_search(
241        &self,
242        query: &str,
243        k: usize,
244        embeddings: &dyn Embeddings,
245    ) -> Result<Vec<Document>, SynapticError> {
246        let results = self
247            .similarity_search_with_score(query, k, embeddings)
248            .await?;
249        Ok(results.into_iter().map(|(doc, _)| doc).collect())
250    }
251
252    async fn similarity_search_with_score(
253        &self,
254        query: &str,
255        k: usize,
256        embeddings: &dyn Embeddings,
257    ) -> Result<Vec<(Document, f32)>, SynapticError> {
258        let query_vec = embeddings.embed_query(query).await?;
259        self.similarity_search_by_vector_with_score(&query_vec, k)
260            .await
261    }
262
263    async fn similarity_search_by_vector(
264        &self,
265        embedding: &[f32],
266        k: usize,
267    ) -> Result<Vec<Document>, SynapticError> {
268        let results = self
269            .similarity_search_by_vector_with_score(embedding, k)
270            .await?;
271        Ok(results.into_iter().map(|(doc, _)| doc).collect())
272    }
273
274    async fn delete(&self, ids: &[&str]) -> Result<(), SynapticError> {
275        let conn = self.conn.clone();
276        let ids: Vec<String> = ids.iter().map(|s| s.to_string()).collect();
277
278        tokio::task::spawn_blocking(move || {
279            let conn = conn
280                .lock()
281                .map_err(|e| SynapticError::VectorStore(format!("lock error: {e}")))?;
282
283            for id in &ids {
284                conn.execute(
285                    "DELETE FROM synaptic_vectors WHERE id = ?1",
286                    rusqlite::params![id],
287                )
288                .map_err(|e| SynapticError::VectorStore(format!("SQLite delete error: {e}")))?;
289
290                conn.execute(
291                    "DELETE FROM synaptic_vectors_fts WHERE id = ?1",
292                    rusqlite::params![id],
293                )
294                .map_err(|e| SynapticError::VectorStore(format!("SQLite FTS delete error: {e}")))?;
295            }
296
297            Ok(())
298        })
299        .await
300        .map_err(|e| SynapticError::VectorStore(format!("spawn_blocking error: {e}")))?
301    }
302}
303
304impl SqliteVectorStore {
305    /// Internal: similarity search by vector returning scores.
306    async fn similarity_search_by_vector_with_score(
307        &self,
308        embedding: &[f32],
309        k: usize,
310    ) -> Result<Vec<(Document, f32)>, SynapticError> {
311        let conn = self.conn.clone();
312        let query_vec = embedding.to_vec();
313
314        tokio::task::spawn_blocking(move || {
315            let conn = conn
316                .lock()
317                .map_err(|e| SynapticError::VectorStore(format!("lock error: {e}")))?;
318
319            let mut stmt = conn
320                .prepare(
321                    "SELECT id, content, metadata, embedding FROM synaptic_vectors
322                     WHERE embedding IS NOT NULL",
323                )
324                .map_err(|e| SynapticError::VectorStore(format!("SQLite prepare error: {e}")))?;
325
326            let mut scored: Vec<(Document, f32)> = stmt
327                .query_map([], |row| {
328                    Ok((
329                        row.get::<_, String>(0)?,
330                        row.get::<_, String>(1)?,
331                        row.get::<_, String>(2)?,
332                        row.get::<_, Vec<u8>>(3)?,
333                    ))
334                })
335                .map_err(|e| SynapticError::VectorStore(format!("SQLite query error: {e}")))?
336                .filter_map(|r| r.ok())
337                .map(|(id, content, meta_str, blob)| {
338                    let metadata: HashMap<String, Value> =
339                        serde_json::from_str(&meta_str).unwrap_or_default();
340                    let embedding = blob_to_embed(&blob);
341                    let score = cosine_similarity(&query_vec, &embedding);
342                    (
343                        Document {
344                            id,
345                            content,
346                            metadata,
347                        },
348                        score,
349                    )
350                })
351                .collect();
352
353            scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
354            scored.truncate(k);
355
356            Ok(scored)
357        })
358        .await
359        .map_err(|e| SynapticError::VectorStore(format!("spawn_blocking error: {e}")))?
360    }
361}
362
/// Encode an embedding as a BLOB: each `f32` becomes its four little-endian
/// bytes, written back to back in vector order.
fn embed_to_blob(embedding: &[f32]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(embedding.len() * 4);
    for value in embedding {
        bytes.extend_from_slice(&value.to_le_bytes());
    }
    bytes
}
367
/// Decode a BLOB written by `embed_to_blob` back into an embedding.
///
/// The blob is read four bytes at a time as little-endian `f32` values. Any
/// trailing bytes that do not form a full 4-byte group are ignored.
fn blob_to_embed(blob: &[u8]) -> Vec<f32> {
    let mut values = Vec::with_capacity(blob.len() / 4);
    for word in blob.chunks_exact(4) {
        values.push(f32::from_le_bytes([word[0], word[1], word[2], word[3]]));
    }
    values
}
374
/// Compute the cosine similarity between two vectors.
///
/// Returns `0.0` (never NaN) for degenerate input: vectors of different
/// lengths, empty vectors, a zero-magnitude vector, or any input that makes
/// the result non-finite (for example NaN or infinite components). Keeping
/// scores finite lets callers rank them with ordinary float comparisons.
///
/// Products and sums are accumulated in `f64`. This means large-magnitude
/// `f32` components cannot overflow the intermediate squared norms, and long
/// vectors lose less precision.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, na, nb), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, na + x * x, nb + y * y)
        },
    );

    if norm_a_sq == 0.0 || norm_b_sq == 0.0 {
        return 0.0;
    }

    let similarity = dot / (norm_a_sq.sqrt() * norm_b_sq.sqrt());
    if similarity.is_finite() {
        similarity as f32
    } else {
        0.0
    }
}