1use std::collections::HashMap;
5use std::future::Future;
6use std::pin::Pin;
7
8use sqlx::SqlitePool;
9
10use crate::vector_store::{
11 FieldValue, ScoredVectorPoint, ScrollResult, VectorFilter, VectorPoint, VectorStore,
12 VectorStoreError,
13};
14
15pub struct SqliteVectorStore {
16 pool: SqlitePool,
17}
18
19impl SqliteVectorStore {
20 #[must_use]
21 pub fn new(pool: SqlitePool) -> Self {
22 Self { pool }
23 }
24}
25
26use crate::math::cosine_similarity;
27
28fn matches_filter(payload: &HashMap<String, serde_json::Value>, filter: &VectorFilter) -> bool {
29 for cond in &filter.must {
30 let Some(val) = payload.get(&cond.field) else {
31 return false;
32 };
33 let matches = match &cond.value {
34 FieldValue::Integer(i) => val.as_i64().is_some_and(|v| v == *i),
35 FieldValue::Text(t) => val.as_str().is_some_and(|v| v == t.as_str()),
36 };
37 if !matches {
38 return false;
39 }
40 }
41 for cond in &filter.must_not {
42 let Some(val) = payload.get(&cond.field) else {
43 continue;
44 };
45 let matches = match &cond.value {
46 FieldValue::Integer(i) => val.as_i64().is_some_and(|v| v == *i),
47 FieldValue::Text(t) => val.as_str().is_some_and(|v| v == t.as_str()),
48 };
49 if matches {
50 return false;
51 }
52 }
53 true
54}
55
56type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
57
58impl VectorStore for SqliteVectorStore {
59 fn ensure_collection(
60 &self,
61 collection: &str,
62 _vector_size: u64,
63 ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
64 let collection = collection.to_owned();
65 Box::pin(async move {
66 sqlx::query("INSERT OR IGNORE INTO vector_collections (name) VALUES (?)")
67 .bind(&collection)
68 .execute(&self.pool)
69 .await
70 .map_err(|e| VectorStoreError::Collection(e.to_string()))?;
71 Ok(())
72 })
73 }
74
75 fn collection_exists(&self, collection: &str) -> BoxFuture<'_, Result<bool, VectorStoreError>> {
76 let collection = collection.to_owned();
77 Box::pin(async move {
78 let row: (i64,) =
79 sqlx::query_as("SELECT COUNT(*) FROM vector_collections WHERE name = ?")
80 .bind(&collection)
81 .fetch_one(&self.pool)
82 .await
83 .map_err(|e| VectorStoreError::Connection(e.to_string()))?;
84 Ok(row.0 > 0)
85 })
86 }
87
88 fn delete_collection(&self, collection: &str) -> BoxFuture<'_, Result<(), VectorStoreError>> {
89 let collection = collection.to_owned();
90 Box::pin(async move {
91 sqlx::query("DELETE FROM vector_points WHERE collection = ?")
92 .bind(&collection)
93 .execute(&self.pool)
94 .await
95 .map_err(|e| VectorStoreError::Delete(e.to_string()))?;
96 sqlx::query("DELETE FROM vector_collections WHERE name = ?")
97 .bind(&collection)
98 .execute(&self.pool)
99 .await
100 .map_err(|e| VectorStoreError::Delete(e.to_string()))?;
101 Ok(())
102 })
103 }
104
105 fn upsert(
106 &self,
107 collection: &str,
108 points: Vec<VectorPoint>,
109 ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
110 let collection = collection.to_owned();
111 Box::pin(async move {
112 for point in points {
113 let vector_bytes: Vec<u8> = bytemuck::cast_slice(&point.vector).to_vec();
114 let payload_json = serde_json::to_string(&point.payload)
115 .map_err(|e| VectorStoreError::Serialization(e.to_string()))?;
116 sqlx::query(
117 "INSERT INTO vector_points (id, collection, vector, payload) VALUES (?, ?, ?, ?) \
118 ON CONFLICT(collection, id) DO UPDATE SET vector = excluded.vector, payload = excluded.payload",
119 )
120 .bind(&point.id)
121 .bind(&collection)
122 .bind(&vector_bytes)
123 .bind(&payload_json)
124 .execute(&self.pool)
125 .await
126 .map_err(|e| VectorStoreError::Upsert(e.to_string()))?;
127 }
128 Ok(())
129 })
130 }
131
132 fn search(
133 &self,
134 collection: &str,
135 vector: Vec<f32>,
136 limit: u64,
137 filter: Option<VectorFilter>,
138 ) -> BoxFuture<'_, Result<Vec<ScoredVectorPoint>, VectorStoreError>> {
139 let collection = collection.to_owned();
140 Box::pin(async move {
141 let rows: Vec<(String, Vec<u8>, String)> = sqlx::query_as(
142 "SELECT id, vector, payload FROM vector_points WHERE collection = ?",
143 )
144 .bind(&collection)
145 .fetch_all(&self.pool)
146 .await
147 .map_err(|e| VectorStoreError::Search(e.to_string()))?;
148
149 let limit_usize = usize::try_from(limit).unwrap_or(usize::MAX);
150 let mut scored: Vec<ScoredVectorPoint> = rows
151 .into_iter()
152 .filter_map(|(id, blob, payload_str)| {
153 let Ok(stored) = bytemuck::try_cast_slice::<u8, f32>(&blob) else {
154 return None;
155 };
156 let payload: HashMap<String, serde_json::Value> =
157 serde_json::from_str(&payload_str).unwrap_or_default();
158
159 if filter
160 .as_ref()
161 .is_some_and(|f| !matches_filter(&payload, f))
162 {
163 return None;
164 }
165
166 let score = cosine_similarity(&vector, stored);
167 Some(ScoredVectorPoint { id, score, payload })
168 })
169 .collect();
170
171 scored.sort_by(|a, b| {
172 b.score
173 .partial_cmp(&a.score)
174 .unwrap_or(std::cmp::Ordering::Equal)
175 });
176 scored.truncate(limit_usize);
177 Ok(scored)
178 })
179 }
180
181 fn delete_by_ids(
182 &self,
183 collection: &str,
184 ids: Vec<String>,
185 ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
186 let collection = collection.to_owned();
187 Box::pin(async move {
188 for id in ids {
189 sqlx::query("DELETE FROM vector_points WHERE collection = ? AND id = ?")
190 .bind(&collection)
191 .bind(&id)
192 .execute(&self.pool)
193 .await
194 .map_err(|e| VectorStoreError::Delete(e.to_string()))?;
195 }
196 Ok(())
197 })
198 }
199
200 fn scroll_all(
201 &self,
202 collection: &str,
203 key_field: &str,
204 ) -> BoxFuture<'_, Result<ScrollResult, VectorStoreError>> {
205 let collection = collection.to_owned();
206 let key_field = key_field.to_owned();
207 Box::pin(async move {
208 let rows: Vec<(String, String)> =
209 sqlx::query_as("SELECT id, payload FROM vector_points WHERE collection = ?")
210 .bind(&collection)
211 .fetch_all(&self.pool)
212 .await
213 .map_err(|e| VectorStoreError::Scroll(e.to_string()))?;
214
215 let mut result = ScrollResult::new();
216 for (id, payload_str) in rows {
217 let payload: HashMap<String, serde_json::Value> =
218 serde_json::from_str(&payload_str).unwrap_or_default();
219 if let Some(val) = payload.get(&key_field) {
220 let mut map = HashMap::new();
221 map.insert(
222 key_field.clone(),
223 val.as_str().unwrap_or_default().to_owned(),
224 );
225 result.insert(id, map);
226 }
227 }
228 Ok(result)
229 })
230 }
231
232 fn health_check(&self) -> BoxFuture<'_, Result<bool, VectorStoreError>> {
233 Box::pin(async move {
234 sqlx::query_scalar::<_, i32>("SELECT 1")
235 .fetch_one(&self.pool)
236 .await
237 .map(|_| true)
238 .map_err(|e| VectorStoreError::Collection(e.to_string()))
239 })
240 }
241}
242
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sqlite::SqliteStore;
    use crate::vector_store::FieldCondition;

    /// Opens a fresh in-memory `SqliteStore` and builds a vector store on a clone of
    /// its pool. The `SqliteStore` is returned too, for tests that write rows directly
    /// through `store.pool()`.
    async fn setup() -> (SqliteVectorStore, SqliteStore) {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let pool = store.pool().clone();
        let vs = SqliteVectorStore::new(pool);
        (vs, store)
    }

    #[tokio::test]
    async fn ensure_and_exists() {
        let (vs, _) = setup().await;
        assert!(!vs.collection_exists("col1").await.unwrap());
        vs.ensure_collection("col1", 4).await.unwrap();
        assert!(vs.collection_exists("col1").await.unwrap());
        // A second call for the same name must succeed and change nothing.
        vs.ensure_collection("col1", 4).await.unwrap();
        assert!(vs.collection_exists("col1").await.unwrap());
    }

    #[tokio::test]
    async fn delete_collection() {
        let (vs, _) = setup().await;
        vs.ensure_collection("col1", 4).await.unwrap();
        vs.upsert(
            "col1",
            vec![VectorPoint {
                id: "p1".into(),
                vector: vec![1.0, 0.0, 0.0, 0.0],
                payload: HashMap::new(),
            }],
        )
        .await
        .unwrap();
        vs.delete_collection("col1").await.unwrap();
        assert!(!vs.collection_exists("col1").await.unwrap());

        // Re-creating the collection must not resurrect the deleted points.
        vs.ensure_collection("col1", 4).await.unwrap();
        let results = vs
            .search("col1", vec![1.0, 0.0, 0.0, 0.0], 10, None)
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn delete_collection_leaves_other_collections() {
        let (vs, _) = setup().await;
        vs.ensure_collection("keep", 4).await.unwrap();
        vs.ensure_collection("drop", 4).await.unwrap();
        vs.upsert(
            "keep",
            vec![VectorPoint {
                id: "k1".into(),
                vector: vec![1.0, 0.0, 0.0, 0.0],
                payload: HashMap::new(),
            }],
        )
        .await
        .unwrap();
        vs.upsert(
            "drop",
            vec![VectorPoint {
                id: "d1".into(),
                vector: vec![1.0, 0.0, 0.0, 0.0],
                payload: HashMap::new(),
            }],
        )
        .await
        .unwrap();

        vs.delete_collection("drop").await.unwrap();

        assert!(vs.collection_exists("keep").await.unwrap());
        let results = vs
            .search("keep", vec![1.0, 0.0, 0.0, 0.0], 10, None)
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "k1");
    }

    #[tokio::test]
    async fn upsert_and_search() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        vs.upsert(
            "c",
            vec![
                VectorPoint {
                    id: "a".into(),
                    vector: vec![1.0, 0.0, 0.0, 0.0],
                    payload: HashMap::from([("role".into(), serde_json::json!("user"))]),
                },
                VectorPoint {
                    id: "b".into(),
                    vector: vec![0.0, 1.0, 0.0, 0.0],
                    payload: HashMap::from([("role".into(), serde_json::json!("assistant"))]),
                },
            ],
        )
        .await
        .unwrap();

        let results = vs
            .search("c", vec![1.0, 0.0, 0.0, 0.0], 2, None)
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "a");
        assert!((results[0].score - 1.0).abs() < 1e-5);
    }

    #[tokio::test]
    async fn search_truncates_to_limit_in_score_order() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        vs.upsert(
            "c",
            vec![
                VectorPoint {
                    id: "orthogonal".into(),
                    vector: vec![0.0, 1.0, 0.0, 0.0],
                    payload: HashMap::new(),
                },
                VectorPoint {
                    id: "close".into(),
                    vector: vec![0.8, 0.6, 0.0, 0.0],
                    payload: HashMap::new(),
                },
                VectorPoint {
                    id: "exact".into(),
                    vector: vec![1.0, 0.0, 0.0, 0.0],
                    payload: HashMap::new(),
                },
            ],
        )
        .await
        .unwrap();

        let results = vs
            .search("c", vec![1.0, 0.0, 0.0, 0.0], 2, None)
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "exact");
        assert_eq!(results[1].id, "close");
        assert!(results[0].score >= results[1].score);
    }

    #[tokio::test]
    async fn search_with_filter() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        vs.upsert(
            "c",
            vec![
                VectorPoint {
                    id: "a".into(),
                    vector: vec![1.0, 0.0, 0.0, 0.0],
                    payload: HashMap::from([("role".into(), serde_json::json!("user"))]),
                },
                VectorPoint {
                    id: "b".into(),
                    vector: vec![1.0, 0.0, 0.0, 0.0],
                    payload: HashMap::from([("role".into(), serde_json::json!("assistant"))]),
                },
            ],
        )
        .await
        .unwrap();

        let filter = VectorFilter {
            must: vec![FieldCondition {
                field: "role".into(),
                value: FieldValue::Text("user".into()),
            }],
            must_not: vec![],
        };
        let results = vs
            .search("c", vec![1.0, 0.0, 0.0, 0.0], 10, Some(filter))
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "a");
    }

    #[tokio::test]
    async fn delete_by_ids() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        vs.upsert(
            "c",
            vec![
                VectorPoint {
                    id: "a".into(),
                    vector: vec![1.0, 0.0, 0.0, 0.0],
                    payload: HashMap::new(),
                },
                VectorPoint {
                    id: "b".into(),
                    vector: vec![0.0, 1.0, 0.0, 0.0],
                    payload: HashMap::new(),
                },
            ],
        )
        .await
        .unwrap();
        vs.delete_by_ids("c", vec!["a".into()]).await.unwrap();
        let results = vs
            .search("c", vec![1.0, 0.0, 0.0, 0.0], 10, None)
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "b");
    }

    #[tokio::test]
    async fn scroll_all() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        vs.upsert(
            "c",
            vec![VectorPoint {
                id: "p1".into(),
                vector: vec![1.0, 0.0, 0.0, 0.0],
                payload: HashMap::from([("text".into(), serde_json::json!("hello"))]),
            }],
        )
        .await
        .unwrap();
        let result = vs.scroll_all("c", "text").await.unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result["p1"]["text"], "hello");
    }

    #[tokio::test]
    async fn upsert_updates_existing() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        vs.upsert(
            "c",
            vec![VectorPoint {
                id: "p1".into(),
                vector: vec![1.0, 0.0, 0.0, 0.0],
                payload: HashMap::from([("v".into(), serde_json::json!(1))]),
            }],
        )
        .await
        .unwrap();
        vs.upsert(
            "c",
            vec![VectorPoint {
                id: "p1".into(),
                vector: vec![0.0, 1.0, 0.0, 0.0],
                payload: HashMap::from([("v".into(), serde_json::json!(2))]),
            }],
        )
        .await
        .unwrap();
        // A generous limit proves the second upsert replaced the row instead of adding one.
        let results = vs
            .search("c", vec![0.0, 1.0, 0.0, 0.0], 10, None)
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert!((results[0].score - 1.0).abs() < 1e-5);
        assert_eq!(results[0].payload["v"], serde_json::json!(2));
    }

    #[tokio::test]
    async fn upsert_empty_batch_is_noop() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        vs.upsert("c", vec![]).await.unwrap();
        let results = vs
            .search("c", vec![1.0, 0.0, 0.0, 0.0], 10, None)
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn cosine_similarity_import_wired() {
        assert!(!cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]).is_nan());
    }

    #[tokio::test]
    async fn search_with_must_not_filter() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        vs.upsert(
            "c",
            vec![
                VectorPoint {
                    id: "a".into(),
                    vector: vec![1.0, 0.0, 0.0, 0.0],
                    payload: HashMap::from([("role".into(), serde_json::json!("user"))]),
                },
                VectorPoint {
                    id: "b".into(),
                    vector: vec![1.0, 0.0, 0.0, 0.0],
                    payload: HashMap::from([("role".into(), serde_json::json!("system"))]),
                },
            ],
        )
        .await
        .unwrap();

        let filter = VectorFilter {
            must: vec![],
            must_not: vec![FieldCondition {
                field: "role".into(),
                value: FieldValue::Text("system".into()),
            }],
        };
        let results = vs
            .search("c", vec![1.0, 0.0, 0.0, 0.0], 10, Some(filter))
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "a");
    }

    #[tokio::test]
    async fn search_with_integer_filter() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        vs.upsert(
            "c",
            vec![
                VectorPoint {
                    id: "a".into(),
                    vector: vec![1.0, 0.0, 0.0, 0.0],
                    payload: HashMap::from([("conv_id".into(), serde_json::json!(1))]),
                },
                VectorPoint {
                    id: "b".into(),
                    vector: vec![1.0, 0.0, 0.0, 0.0],
                    payload: HashMap::from([("conv_id".into(), serde_json::json!(2))]),
                },
            ],
        )
        .await
        .unwrap();

        let filter = VectorFilter {
            must: vec![FieldCondition {
                field: "conv_id".into(),
                value: FieldValue::Integer(1),
            }],
            must_not: vec![],
        };
        let results = vs
            .search("c", vec![1.0, 0.0, 0.0, 0.0], 10, Some(filter))
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "a");
    }

    #[tokio::test]
    async fn search_empty_collection() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        let results = vs
            .search("c", vec![1.0, 0.0, 0.0, 0.0], 10, None)
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn search_with_must_not_integer_filter() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        vs.upsert(
            "c",
            vec![
                VectorPoint {
                    id: "a".into(),
                    vector: vec![1.0, 0.0, 0.0, 0.0],
                    payload: HashMap::from([("conv_id".into(), serde_json::json!(1))]),
                },
                VectorPoint {
                    id: "b".into(),
                    vector: vec![1.0, 0.0, 0.0, 0.0],
                    payload: HashMap::from([("conv_id".into(), serde_json::json!(2))]),
                },
            ],
        )
        .await
        .unwrap();

        let filter = VectorFilter {
            must: vec![],
            must_not: vec![FieldCondition {
                field: "conv_id".into(),
                value: FieldValue::Integer(1),
            }],
        };
        let results = vs
            .search("c", vec![1.0, 0.0, 0.0, 0.0], 10, Some(filter))
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "b");
    }

    #[tokio::test]
    async fn search_with_combined_must_and_must_not() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        vs.upsert(
            "c",
            vec![
                VectorPoint {
                    id: "a".into(),
                    vector: vec![1.0, 0.0, 0.0, 0.0],
                    payload: HashMap::from([
                        ("role".into(), serde_json::json!("user")),
                        ("conv_id".into(), serde_json::json!(1)),
                    ]),
                },
                VectorPoint {
                    id: "b".into(),
                    vector: vec![1.0, 0.0, 0.0, 0.0],
                    payload: HashMap::from([
                        ("role".into(), serde_json::json!("user")),
                        ("conv_id".into(), serde_json::json!(2)),
                    ]),
                },
                VectorPoint {
                    id: "c".into(),
                    vector: vec![1.0, 0.0, 0.0, 0.0],
                    payload: HashMap::from([
                        ("role".into(), serde_json::json!("assistant")),
                        ("conv_id".into(), serde_json::json!(1)),
                    ]),
                },
            ],
        )
        .await
        .unwrap();

        let filter = VectorFilter {
            must: vec![FieldCondition {
                field: "role".into(),
                value: FieldValue::Text("user".into()),
            }],
            must_not: vec![FieldCondition {
                field: "conv_id".into(),
                value: FieldValue::Integer(2),
            }],
        };
        let results = vs
            .search("c", vec![1.0, 0.0, 0.0, 0.0], 10, Some(filter))
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "a");
    }

    #[tokio::test]
    async fn scroll_all_missing_key_field() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        vs.upsert(
            "c",
            vec![VectorPoint {
                id: "p1".into(),
                vector: vec![1.0, 0.0, 0.0, 0.0],
                payload: HashMap::from([("other".into(), serde_json::json!("value"))]),
            }],
        )
        .await
        .unwrap();
        let result = vs.scroll_all("c", "text").await.unwrap();
        assert!(
            result.is_empty(),
            "points without the key field must not appear in scroll result"
        );
    }

    #[tokio::test]
    async fn delete_by_ids_empty_and_nonexistent() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        vs.upsert(
            "c",
            vec![VectorPoint {
                id: "a".into(),
                vector: vec![1.0, 0.0, 0.0, 0.0],
                payload: HashMap::new(),
            }],
        )
        .await
        .unwrap();

        vs.delete_by_ids("c", vec![]).await.unwrap();

        vs.delete_by_ids("c", vec!["nonexistent".into()])
            .await
            .unwrap();

        let results = vs
            .search("c", vec![1.0, 0.0, 0.0, 0.0], 10, None)
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "a");
    }

    #[tokio::test]
    async fn search_corrupt_blob_skipped() {
        let (vs, store) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();

        vs.upsert(
            "c",
            vec![VectorPoint {
                id: "valid".into(),
                vector: vec![1.0, 0.0, 0.0, 0.0],
                payload: HashMap::new(),
            }],
        )
        .await
        .unwrap();

        // Three bytes cannot hold a whole f32, so this row must be skipped by `search`.
        let corrupt_blob: Vec<u8> = vec![0xFF, 0xFE, 0xFD];
        let payload_json = r"{}";
        sqlx::query(
            "INSERT INTO vector_points (id, collection, vector, payload) VALUES (?, ?, ?, ?)",
        )
        .bind("corrupt")
        .bind("c")
        .bind(&corrupt_blob)
        .bind(payload_json)
        .execute(store.pool())
        .await
        .unwrap();

        let results = vs
            .search("c", vec![1.0, 0.0, 0.0, 0.0], 10, None)
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "valid");
    }
}