1use std::collections::HashMap;
2use std::sync::{Arc, Mutex};
3
4use async_trait::async_trait;
5use rusqlite::Connection;
6use serde_json::Value;
7use synaptic_core::{Document, Embeddings, SynapticError, VectorStore};
8
/// Settings used to open a `SqliteVectorStore`.
#[derive(Debug, Clone)]
pub struct SqliteVectorStoreConfig {
    /// Location of the SQLite database handed to `Connection::open`;
    /// the special value `":memory:"` selects an in-memory database.
    pub path: String,
}

impl SqliteVectorStoreConfig {
    /// Builds a configuration pointing at the database `path`.
    pub fn new(path: impl Into<String>) -> Self {
        let path: String = path.into();
        Self { path }
    }

    /// Builds a configuration for an in-memory database.
    ///
    /// Shorthand for `SqliteVectorStoreConfig::new(":memory:")`.
    pub fn in_memory() -> Self {
        Self::new(":memory:")
    }
}
29
30pub struct SqliteVectorStore {
36 conn: Arc<Mutex<Connection>>,
37}
38
39impl SqliteVectorStore {
40 pub fn new(config: SqliteVectorStoreConfig) -> Result<Self, SynapticError> {
45 let conn = Connection::open(&config.path)
46 .map_err(|e| SynapticError::VectorStore(format!("SQLite open error: {e}")))?;
47
48 conn.execute_batch(
49 "CREATE TABLE IF NOT EXISTS synaptic_vectors (
50 id TEXT PRIMARY KEY,
51 content TEXT NOT NULL,
52 metadata TEXT NOT NULL DEFAULT '{}',
53 embedding BLOB
54 );
55 CREATE VIRTUAL TABLE IF NOT EXISTS synaptic_vectors_fts USING fts5(
56 content, id UNINDEXED
57 );",
58 )
59 .map_err(|e| SynapticError::VectorStore(format!("SQLite create table error: {e}")))?;
60
61 Ok(Self {
62 conn: Arc::new(Mutex::new(conn)),
63 })
64 }
65
66 pub async fn hybrid_search(
75 &self,
76 query: &str,
77 k: usize,
78 embeddings: &dyn Embeddings,
79 alpha: f32,
80 ) -> Result<Vec<(Document, f32)>, SynapticError> {
81 let query_vec = embeddings.embed_query(query).await?;
82 let conn = self.conn.clone();
83 let query = query.to_string();
84
85 tokio::task::spawn_blocking(move || {
86 let conn = conn
87 .lock()
88 .map_err(|e| SynapticError::VectorStore(format!("lock error: {e}")))?;
89
90 let fts_results: HashMap<String, f64> = {
93 let mut stmt = conn
94 .prepare(
95 "SELECT id, bm25(synaptic_vectors_fts) as score
96 FROM synaptic_vectors_fts WHERE synaptic_vectors_fts MATCH ?1",
97 )
98 .map_err(|e| {
99 SynapticError::VectorStore(format!("SQLite FTS prepare error: {e}"))
100 })?;
101
102 let rows: Vec<(String, f64)> = stmt
103 .query_map(rusqlite::params![query], |row| {
104 Ok((row.get::<_, String>(0)?, row.get::<_, f64>(1)?))
105 })
106 .map_err(|e| {
107 SynapticError::VectorStore(format!("SQLite FTS query error: {e}"))
108 })?
109 .filter_map(|r| r.ok())
110 .collect();
111
112 rows.into_iter().collect()
113 };
114
115 let mut stmt = conn
117 .prepare(
118 "SELECT id, content, metadata, embedding FROM synaptic_vectors
119 WHERE embedding IS NOT NULL",
120 )
121 .map_err(|e| SynapticError::VectorStore(format!("SQLite prepare error: {e}")))?;
122
123 let all_docs: Vec<(Document, Vec<f32>)> = stmt
124 .query_map([], |row| {
125 Ok((
126 row.get::<_, String>(0)?,
127 row.get::<_, String>(1)?,
128 row.get::<_, String>(2)?,
129 row.get::<_, Vec<u8>>(3)?,
130 ))
131 })
132 .map_err(|e| SynapticError::VectorStore(format!("SQLite query error: {e}")))?
133 .filter_map(|r| r.ok())
134 .map(|(id, content, meta_str, blob)| {
135 let metadata: HashMap<String, Value> =
136 serde_json::from_str(&meta_str).unwrap_or_default();
137 let embedding = blob_to_embed(&blob);
138 (
139 Document {
140 id,
141 content,
142 metadata,
143 },
144 embedding,
145 )
146 })
147 .collect();
148
149 let bm25_max = fts_results
152 .values()
153 .map(|s| -s) .fold(f64::NEG_INFINITY, f64::max);
155 let bm25_max = if bm25_max <= 0.0 { 1.0 } else { bm25_max };
156
157 let mut scored: Vec<(Document, f32)> = all_docs
159 .into_iter()
160 .map(|(doc, emb)| {
161 let cosine = cosine_similarity(&query_vec, &emb);
162 let bm25_raw = fts_results.get(&doc.id).copied().unwrap_or(0.0);
163 let bm25_normalized = (-bm25_raw / bm25_max) as f32;
164 let score = alpha * cosine + (1.0 - alpha) * bm25_normalized;
165 (doc, score)
166 })
167 .collect();
168
169 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
170 scored.truncate(k);
171
172 Ok(scored)
173 })
174 .await
175 .map_err(|e| SynapticError::VectorStore(format!("spawn_blocking error: {e}")))?
176 }
177}
178
179#[async_trait]
180impl VectorStore for SqliteVectorStore {
181 async fn add_documents(
182 &self,
183 docs: Vec<Document>,
184 embeddings: &dyn Embeddings,
185 ) -> Result<Vec<String>, SynapticError> {
186 let texts: Vec<&str> = docs.iter().map(|d| d.content.as_str()).collect();
187 let vectors = embeddings.embed_documents(&texts).await?;
188
189 let conn = self.conn.clone();
190
191 let docs_with_vecs: Vec<(Document, Vec<f32>)> = docs.into_iter().zip(vectors).collect();
192
193 tokio::task::spawn_blocking(move || {
194 let conn = conn
195 .lock()
196 .map_err(|e| SynapticError::VectorStore(format!("lock error: {e}")))?;
197
198 let mut ids = Vec::with_capacity(docs_with_vecs.len());
199
200 for (mut doc, embedding) in docs_with_vecs {
201 if doc.id.is_empty() {
203 doc.id = uuid::Uuid::new_v4().to_string();
204 }
205
206 let meta_str = serde_json::to_string(&doc.metadata).map_err(|e| {
207 SynapticError::VectorStore(format!("JSON serialize error: {e}"))
208 })?;
209 let blob = embed_to_blob(&embedding);
210
211 conn.execute(
212 "INSERT OR REPLACE INTO synaptic_vectors (id, content, metadata, embedding)
213 VALUES (?1, ?2, ?3, ?4)",
214 rusqlite::params![doc.id, doc.content, meta_str, blob],
215 )
216 .map_err(|e| SynapticError::VectorStore(format!("SQLite insert error: {e}")))?;
217
218 conn.execute(
220 "DELETE FROM synaptic_vectors_fts WHERE id = ?1",
221 rusqlite::params![doc.id],
222 )
223 .map_err(|e| SynapticError::VectorStore(format!("SQLite FTS delete error: {e}")))?;
224
225 conn.execute(
226 "INSERT INTO synaptic_vectors_fts (content, id) VALUES (?1, ?2)",
227 rusqlite::params![doc.content, doc.id],
228 )
229 .map_err(|e| SynapticError::VectorStore(format!("SQLite FTS insert error: {e}")))?;
230
231 ids.push(doc.id);
232 }
233
234 Ok(ids)
235 })
236 .await
237 .map_err(|e| SynapticError::VectorStore(format!("spawn_blocking error: {e}")))?
238 }
239
240 async fn similarity_search(
241 &self,
242 query: &str,
243 k: usize,
244 embeddings: &dyn Embeddings,
245 ) -> Result<Vec<Document>, SynapticError> {
246 let results = self
247 .similarity_search_with_score(query, k, embeddings)
248 .await?;
249 Ok(results.into_iter().map(|(doc, _)| doc).collect())
250 }
251
252 async fn similarity_search_with_score(
253 &self,
254 query: &str,
255 k: usize,
256 embeddings: &dyn Embeddings,
257 ) -> Result<Vec<(Document, f32)>, SynapticError> {
258 let query_vec = embeddings.embed_query(query).await?;
259 self.similarity_search_by_vector_with_score(&query_vec, k)
260 .await
261 }
262
263 async fn similarity_search_by_vector(
264 &self,
265 embedding: &[f32],
266 k: usize,
267 ) -> Result<Vec<Document>, SynapticError> {
268 let results = self
269 .similarity_search_by_vector_with_score(embedding, k)
270 .await?;
271 Ok(results.into_iter().map(|(doc, _)| doc).collect())
272 }
273
274 async fn delete(&self, ids: &[&str]) -> Result<(), SynapticError> {
275 let conn = self.conn.clone();
276 let ids: Vec<String> = ids.iter().map(|s| s.to_string()).collect();
277
278 tokio::task::spawn_blocking(move || {
279 let conn = conn
280 .lock()
281 .map_err(|e| SynapticError::VectorStore(format!("lock error: {e}")))?;
282
283 for id in &ids {
284 conn.execute(
285 "DELETE FROM synaptic_vectors WHERE id = ?1",
286 rusqlite::params![id],
287 )
288 .map_err(|e| SynapticError::VectorStore(format!("SQLite delete error: {e}")))?;
289
290 conn.execute(
291 "DELETE FROM synaptic_vectors_fts WHERE id = ?1",
292 rusqlite::params![id],
293 )
294 .map_err(|e| SynapticError::VectorStore(format!("SQLite FTS delete error: {e}")))?;
295 }
296
297 Ok(())
298 })
299 .await
300 .map_err(|e| SynapticError::VectorStore(format!("spawn_blocking error: {e}")))?
301 }
302}
303
304impl SqliteVectorStore {
305 async fn similarity_search_by_vector_with_score(
307 &self,
308 embedding: &[f32],
309 k: usize,
310 ) -> Result<Vec<(Document, f32)>, SynapticError> {
311 let conn = self.conn.clone();
312 let query_vec = embedding.to_vec();
313
314 tokio::task::spawn_blocking(move || {
315 let conn = conn
316 .lock()
317 .map_err(|e| SynapticError::VectorStore(format!("lock error: {e}")))?;
318
319 let mut stmt = conn
320 .prepare(
321 "SELECT id, content, metadata, embedding FROM synaptic_vectors
322 WHERE embedding IS NOT NULL",
323 )
324 .map_err(|e| SynapticError::VectorStore(format!("SQLite prepare error: {e}")))?;
325
326 let mut scored: Vec<(Document, f32)> = stmt
327 .query_map([], |row| {
328 Ok((
329 row.get::<_, String>(0)?,
330 row.get::<_, String>(1)?,
331 row.get::<_, String>(2)?,
332 row.get::<_, Vec<u8>>(3)?,
333 ))
334 })
335 .map_err(|e| SynapticError::VectorStore(format!("SQLite query error: {e}")))?
336 .filter_map(|r| r.ok())
337 .map(|(id, content, meta_str, blob)| {
338 let metadata: HashMap<String, Value> =
339 serde_json::from_str(&meta_str).unwrap_or_default();
340 let embedding = blob_to_embed(&blob);
341 let score = cosine_similarity(&query_vec, &embedding);
342 (
343 Document {
344 id,
345 content,
346 metadata,
347 },
348 score,
349 )
350 })
351 .collect();
352
353 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
354 scored.truncate(k);
355
356 Ok(scored)
357 })
358 .await
359 .map_err(|e| SynapticError::VectorStore(format!("spawn_blocking error: {e}")))?
360 }
361}
362
/// Serializes an embedding as the concatenation of each component's
/// 4-byte little-endian IEEE-754 encoding.
fn embed_to_blob(embedding: &[f32]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(embedding.len() * std::mem::size_of::<f32>());
    for value in embedding {
        bytes.extend_from_slice(&value.to_le_bytes());
    }
    bytes
}
367
/// Decodes a BLOB of consecutive 4-byte little-endian `f32` values.
///
/// A trailing group shorter than 4 bytes is ignored.
fn blob_to_embed(blob: &[u8]) -> Vec<f32> {
    let mut values = Vec::with_capacity(blob.len() / 4);
    for chunk in blob.chunks_exact(4) {
        let mut word = [0u8; 4];
        word.copy_from_slice(chunk);
        values.push(f32::from_le_bytes(word));
    }
    values
}
374
/// Cosine similarity of `a` and `b`: their dot product divided by the
/// product of their Euclidean norms.
///
/// Returns `0.0` when the slices differ in length, are empty, or either
/// has zero magnitude.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let comparable = !a.is_empty() && a.len() == b.len();
    if !comparable {
        return 0.0;
    }

    let norm = |v: &[f32]| -> f32 { v.iter().map(|x| x * x).sum::<f32>().sqrt() };
    let dot = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
    let (norm_a, norm_b) = (norm(a), norm(b));

    if norm_a == 0.0 || norm_b == 0.0 {
        0.0
    } else {
        dot / (norm_a * norm_b)
    }
}