1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::fmt::Debug;
4
5use crate::error::StoreError;
6use crate::traits::{StoreLifecycle, VectorStore};
7use crate::types::{MetadataFilter, SearchResult, StoreStats, VectorEntry};
8
9#[derive(Debug)]
17pub struct SqliteVecStore {
18 pool: sqlx::SqlitePool,
19 dimensions: usize,
20 table_name: String,
21 max_capacity: Option<u64>,
22}
23
24impl SqliteVecStore {
25 pub async fn new(
27 path: &str,
28 dimensions: usize,
29 max_pool_size: Option<usize>,
30 ) -> Result<Self, StoreError> {
31 let pool_size = max_pool_size.unwrap_or(5);
32 let conn_str = if path == ":memory:" {
33 "sqlite::memory:".to_string()
34 } else {
35 format!("sqlite:{path}")
36 };
37
38 let pool = sqlx::sqlite::SqlitePoolOptions::new()
39 .max_connections(pool_size as u32)
40 .connect(&conn_str)
41 .await
42 .map_err(|e| StoreError::Database(e.to_string()))?;
43
44 sqlx::query("PRAGMA journal_mode=WAL")
46 .execute(&pool)
47 .await
48 .map_err(|e| StoreError::Database(e.to_string()))?;
49
50 Ok(Self {
51 pool,
52 dimensions,
53 table_name: "embeddings".into(),
54 max_capacity: None,
55 })
56 }
57
58 pub fn with_table_name(mut self, name: &str) -> Self {
60 self.table_name = name.to_string();
61 self
62 }
63
64 pub fn with_max_capacity(mut self, capacity: Option<u64>) -> Self {
66 self.max_capacity = capacity;
67 self
68 }
69
70 pub async fn purge_expired(&self) -> Result<usize, StoreError> {
72 let now_ms = std::time::SystemTime::now()
73 .duration_since(std::time::UNIX_EPOCH)
74 .unwrap_or_default()
75 .as_millis() as u64;
76
77 let result = sqlx::query(&format!(
78 "DELETE FROM {} WHERE expires_at IS NOT NULL AND expires_at < ?",
79 self.table_name
80 ))
81 .bind(now_ms as i64)
82 .execute(&self.pool)
83 .await
84 .map_err(|e| StoreError::Database(e.to_string()))?;
85
86 Ok(result.rows_affected() as usize)
87 }
88
89 fn build_filter_clause(filter: &MetadataFilter) -> (String, Vec<String>) {
90 match filter {
91 MetadataFilter::Eq { key, value } => (
92 format!("json_extract(metadata_json, '$.{key}') = ?"),
93 vec![value.clone()],
94 ),
95 MetadataFilter::Ne { key, value } => (
96 format!("json_extract(metadata_json, '$.{key}') != ?"),
97 vec![value.clone()],
98 ),
99 MetadataFilter::In { key, values } => {
100 let placeholders: Vec<String> = values.iter().map(|_| "?".to_string()).collect();
101 (
102 format!(
103 "json_extract(metadata_json, '$.{key}') IN ({})",
104 placeholders.join(", ")
105 ),
106 values.clone(),
107 )
108 }
109 MetadataFilter::NotIn { key, values } => {
110 let placeholders: Vec<String> = values.iter().map(|_| "?".to_string()).collect();
111 (
112 format!(
113 "json_extract(metadata_json, '$.{key}') NOT IN ({})",
114 placeholders.join(", ")
115 ),
116 values.clone(),
117 )
118 }
119 MetadataFilter::Exists { key } => (
120 format!("json_extract(metadata_json, '$.{key}') IS NOT NULL"),
121 vec![],
122 ),
123 MetadataFilter::Contains { key, value } => (
124 format!("json_extract(metadata_json, '$.{key}') LIKE ?"),
125 vec![format!("%{value}%")],
126 ),
127 MetadataFilter::Range { key, min, max } => {
128 let mut clauses = Vec::new();
129 let mut params = Vec::new();
130 if let Some(min_val) = min {
131 clauses.push(format!(
132 "CAST(json_extract(metadata_json, '$.{key}') AS REAL) >= ?"
133 ));
134 params.push(min_val.to_string());
135 }
136 if let Some(max_val) = max {
137 clauses.push(format!(
138 "CAST(json_extract(metadata_json, '$.{key}') AS REAL) <= ?"
139 ));
140 params.push(max_val.to_string());
141 }
142 (clauses.join(" AND "), params)
143 }
144 MetadataFilter::And(filters) => {
145 let mut clauses = Vec::new();
146 let mut all_params = Vec::new();
147 for f in filters {
148 let (clause, mut params) = Self::build_filter_clause(f);
149 if !clause.is_empty() {
150 clauses.push(format!("({clause})"));
151 all_params.append(&mut params);
152 }
153 }
154 (clauses.join(" AND "), all_params)
155 }
156 MetadataFilter::Or(filters) => {
157 let mut clauses = Vec::new();
158 let mut all_params = Vec::new();
159 for f in filters {
160 let (clause, mut params) = Self::build_filter_clause(f);
161 if !clause.is_empty() {
162 clauses.push(format!("({clause})"));
163 all_params.append(&mut params);
164 }
165 }
166 (clauses.join(" OR "), all_params)
167 }
168 MetadataFilter::Not(filter) => {
169 let (inner, params) = Self::build_filter_clause(filter);
170 (format!("NOT ({inner})"), params)
171 }
172 }
173 }
174
175 fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
176 if a.len() != b.len() {
177 return 0.0;
178 }
179 let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
180 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
181 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
182 if norm_a == 0.0 || norm_b == 0.0 {
183 return 0.0;
184 }
185 dot / (norm_a * norm_b)
186 }
187
188 fn vector_to_blob(v: &[f32]) -> Vec<u8> {
189 v.iter().flat_map(|f| f.to_le_bytes()).collect()
190 }
191
192 fn blob_to_vector(b: &[u8]) -> Vec<f32> {
193 b.chunks_exact(4)
194 .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
195 .collect()
196 }
197}
198
199#[async_trait]
200impl VectorStore for SqliteVecStore {
201 async fn insert(&self, entry: VectorEntry) -> Result<(), StoreError> {
202 self.insert_batch(vec![entry]).await
203 }
204
205 async fn insert_batch(&self, entries: Vec<VectorEntry>) -> Result<(), StoreError> {
206 if entries.is_empty() {
207 return Ok(());
208 }
209
210 for entry in &entries {
212 if entry.vector.len() != self.dimensions {
213 return Err(StoreError::DimensionMismatch {
214 expected: self.dimensions,
215 actual: entry.vector.len(),
216 });
217 }
218 }
219
220 for entry in entries {
221 let vector_blob = Self::vector_to_blob(&entry.vector);
222 let metadata_json = serde_json::to_string(&entry.metadata)
223 .map_err(|e| StoreError::Serialization(e.to_string()))?;
224
225 sqlx::query(&format!(
226 "INSERT OR REPLACE INTO {} (id, content, metadata_json, channel, created_at, expires_at, embedding) VALUES (?, ?, ?, ?, ?, ?, ?)",
227 self.table_name
228 ))
229 .bind(&entry.id)
230 .bind(&entry.content)
231 .bind(&metadata_json)
232 .bind(&entry.channel)
233 .bind(entry.created_at as i64)
234 .bind(entry.expires_at.map(|t| t as i64))
235 .bind(&vector_blob)
236 .execute(&self.pool)
237 .await
238 .map_err(|e| StoreError::Database(e.to_string()))?;
239 }
240
241 Ok(())
242 }
243
244 async fn search(&self, query: &[f32], limit: usize) -> Result<Vec<SearchResult>, StoreError> {
245 if query.len() != self.dimensions {
246 return Err(StoreError::DimensionMismatch {
247 expected: self.dimensions,
248 actual: query.len(),
249 });
250 }
251
252 let rows = sqlx::query_as::<_, EmbeddingRow>(&format!(
253 "SELECT id, content, metadata_json, channel, embedding FROM {}",
254 self.table_name
255 ))
256 .fetch_all(&self.pool)
257 .await
258 .map_err(|e| StoreError::Database(e.to_string()))?;
259
260 let mut scored: Vec<(SearchResult, f32)> = rows
261 .iter()
262 .map(|row| {
263 let vector = Self::blob_to_vector(&row.embedding);
264 let similarity = Self::cosine_similarity(query, &vector);
265 let metadata: HashMap<String, String> =
266 serde_json::from_str(&row.metadata_json).unwrap_or_default();
267
268 (
269 SearchResult {
270 id: row.id.clone(),
271 score: similarity,
272 metadata,
273 content: row.content.clone(),
274 channel: row.channel.clone(),
275 },
276 similarity,
277 )
278 })
279 .collect();
280
281 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
282 scored.truncate(limit);
283
284 Ok(scored.into_iter().map(|(r, _)| r).collect())
285 }
286
287 async fn search_with_filter(
288 &self,
289 query: &[f32],
290 filter: &MetadataFilter,
291 limit: usize,
292 ) -> Result<Vec<SearchResult>, StoreError> {
293 if query.len() != self.dimensions {
294 return Err(StoreError::DimensionMismatch {
295 expected: self.dimensions,
296 actual: query.len(),
297 });
298 }
299
300 let (filter_clause, params) = Self::build_filter_clause(filter);
301 let sql = format!(
302 "SELECT id, content, metadata_json, channel, embedding FROM {} WHERE {}",
303 self.table_name, filter_clause
304 );
305
306 let mut query_builder = sqlx::query_as::<_, EmbeddingRow>(&sql);
307 for param in ¶ms {
308 query_builder = query_builder.bind(param);
309 }
310
311 let rows = query_builder
312 .fetch_all(&self.pool)
313 .await
314 .map_err(|e| StoreError::Database(e.to_string()))?;
315
316 let mut scored: Vec<(SearchResult, f32)> = rows
317 .iter()
318 .map(|row| {
319 let vector = Self::blob_to_vector(&row.embedding);
320 let similarity = Self::cosine_similarity(query, &vector);
321 let metadata: HashMap<String, String> =
322 serde_json::from_str(&row.metadata_json).unwrap_or_default();
323
324 (
325 SearchResult {
326 id: row.id.clone(),
327 score: similarity,
328 metadata,
329 content: row.content.clone(),
330 channel: row.channel.clone(),
331 },
332 similarity,
333 )
334 })
335 .collect();
336
337 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
338 scored.truncate(limit);
339
340 Ok(scored.into_iter().map(|(r, _)| r).collect())
341 }
342
343 async fn delete(&self, ids: &[String]) -> Result<usize, StoreError> {
344 if ids.is_empty() {
345 return Ok(0);
346 }
347 let placeholders: Vec<String> = ids.iter().map(|_| "?".to_string()).collect();
348 let sql = format!(
349 "DELETE FROM {} WHERE id IN ({})",
350 self.table_name,
351 placeholders.join(", ")
352 );
353
354 let mut query = sqlx::query(&sql);
355 for id in ids {
356 query = query.bind(id);
357 }
358
359 let result = query
360 .execute(&self.pool)
361 .await
362 .map_err(|e| StoreError::Database(e.to_string()))?;
363
364 Ok(result.rows_affected() as usize)
365 }
366
367 async fn delete_by_filter(&self, filter: &MetadataFilter) -> Result<usize, StoreError> {
368 let (filter_clause, params) = Self::build_filter_clause(filter);
369 let sql = format!("DELETE FROM {} WHERE {}", self.table_name, filter_clause);
370
371 let mut query = sqlx::query(&sql);
372 for param in ¶ms {
373 query = query.bind(param);
374 }
375
376 let result = query
377 .execute(&self.pool)
378 .await
379 .map_err(|e| StoreError::Database(e.to_string()))?;
380
381 Ok(result.rows_affected() as usize)
382 }
383
384 async fn clear(&self) -> Result<(), StoreError> {
385 sqlx::query(&format!("DELETE FROM {}", self.table_name))
386 .execute(&self.pool)
387 .await
388 .map_err(|e| StoreError::Database(e.to_string()))?;
389 Ok(())
390 }
391
392 async fn count(&self) -> Result<usize, StoreError> {
393 let (count,): (i64,) = sqlx::query_as(&format!(
394 "SELECT COUNT(*) FROM {}",
395 self.table_name
396 ))
397 .fetch_one(&self.pool)
398 .await
399 .map_err(|e| StoreError::Database(e.to_string()))?;
400
401 Ok(count as usize)
402 }
403
404 async fn rebuild_index(&self) -> Result<(), StoreError> {
405 sqlx::query("PRAGMA wal_checkpoint(FULL)")
407 .execute(&self.pool)
408 .await
409 .map_err(|e| StoreError::Database(e.to_string()))?;
410 Ok(())
411 }
412
413 async fn stats(&self) -> Result<StoreStats, StoreError> {
414 let count = self.count().await?;
415 Ok(StoreStats {
416 total_vectors: count,
417 total_dimensions: self.dimensions,
418 index_size_bytes: 0,
419 data_size_bytes: 0,
420 last_indexed_at: None,
421 })
422 }
423}
424
425#[async_trait]
426impl StoreLifecycle for SqliteVecStore {
427 async fn initialize(&self) -> Result<(), StoreError> {
428 sqlx::query(&format!(
429 "CREATE TABLE IF NOT EXISTS {} (
430 id TEXT PRIMARY KEY,
431 content TEXT,
432 metadata_json TEXT,
433 channel TEXT,
434 created_at INTEGER NOT NULL,
435 expires_at INTEGER,
436 embedding BLOB NOT NULL
437 )",
438 self.table_name
439 ))
440 .execute(&self.pool)
441 .await
442 .map_err(|e| StoreError::Database(e.to_string()))?;
443
444 Ok(())
445 }
446
447 async fn close(&self) -> Result<(), StoreError> {
448 self.pool.close().await;
449 Ok(())
450 }
451
452 async fn checkpoint(&self) -> Result<(), StoreError> {
453 sqlx::query("PRAGMA wal_checkpoint(FULL)")
454 .execute(&self.pool)
455 .await
456 .map_err(|e| StoreError::Database(e.to_string()))?;
457 Ok(())
458 }
459
460 async fn health_check(&self) -> Result<bool, StoreError> {
461 sqlx::query("SELECT 1")
462 .execute(&self.pool)
463 .await
464 .map_err(|e| StoreError::Database(e.to_string()))?;
465 Ok(true)
466 }
467}
468
469#[derive(Debug, sqlx::FromRow)]
470struct EmbeddingRow {
471 id: String,
472 content: Option<String>,
473 metadata_json: String,
474 channel: Option<String>,
475 embedding: Vec<u8>,
476}