1use std::collections::HashMap;
5
6use sqlx::SqlitePool;
7
8use crate::vector_store::{
9 BoxFuture, FieldValue, ScoredVectorPoint, ScrollResult, VectorFilter, VectorPoint, VectorStore,
10 VectorStoreError,
11};
12
13pub struct SqliteVectorStore {
14 pool: SqlitePool,
15}
16
17impl SqliteVectorStore {
18 #[must_use]
19 pub fn new(pool: SqlitePool) -> Self {
20 Self { pool }
21 }
22}
23
24use crate::math::cosine_similarity;
25
26fn matches_filter(payload: &HashMap<String, serde_json::Value>, filter: &VectorFilter) -> bool {
27 for cond in &filter.must {
28 let Some(val) = payload.get(&cond.field) else {
29 return false;
30 };
31 let matches = match &cond.value {
32 FieldValue::Integer(i) => val.as_i64().is_some_and(|v| v == *i),
33 FieldValue::Text(t) => val.as_str().is_some_and(|v| v == t.as_str()),
34 };
35 if !matches {
36 return false;
37 }
38 }
39 for cond in &filter.must_not {
40 let Some(val) = payload.get(&cond.field) else {
41 continue;
42 };
43 let matches = match &cond.value {
44 FieldValue::Integer(i) => val.as_i64().is_some_and(|v| v == *i),
45 FieldValue::Text(t) => val.as_str().is_some_and(|v| v == t.as_str()),
46 };
47 if matches {
48 return false;
49 }
50 }
51 true
52}
53
54impl VectorStore for SqliteVectorStore {
55 fn ensure_collection(
56 &self,
57 collection: &str,
58 _vector_size: u64,
59 ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
60 let collection = collection.to_owned();
61 Box::pin(async move {
62 sqlx::query("INSERT OR IGNORE INTO vector_collections (name) VALUES (?)")
63 .bind(&collection)
64 .execute(&self.pool)
65 .await
66 .map_err(|e| VectorStoreError::Collection(e.to_string()))?;
67 Ok(())
68 })
69 }
70
71 fn collection_exists(&self, collection: &str) -> BoxFuture<'_, Result<bool, VectorStoreError>> {
72 let collection = collection.to_owned();
73 Box::pin(async move {
74 let row: (i64,) =
75 sqlx::query_as("SELECT COUNT(*) FROM vector_collections WHERE name = ?")
76 .bind(&collection)
77 .fetch_one(&self.pool)
78 .await
79 .map_err(|e| VectorStoreError::Connection(e.to_string()))?;
80 Ok(row.0 > 0)
81 })
82 }
83
84 fn delete_collection(&self, collection: &str) -> BoxFuture<'_, Result<(), VectorStoreError>> {
85 let collection = collection.to_owned();
86 Box::pin(async move {
87 sqlx::query("DELETE FROM vector_points WHERE collection = ?")
88 .bind(&collection)
89 .execute(&self.pool)
90 .await
91 .map_err(|e| VectorStoreError::Delete(e.to_string()))?;
92 sqlx::query("DELETE FROM vector_collections WHERE name = ?")
93 .bind(&collection)
94 .execute(&self.pool)
95 .await
96 .map_err(|e| VectorStoreError::Delete(e.to_string()))?;
97 Ok(())
98 })
99 }
100
101 fn upsert(
102 &self,
103 collection: &str,
104 points: Vec<VectorPoint>,
105 ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
106 let collection = collection.to_owned();
107 Box::pin(async move {
108 for point in points {
109 let vector_bytes: Vec<u8> = bytemuck::cast_slice(&point.vector).to_vec();
110 let payload_json = serde_json::to_string(&point.payload)
111 .map_err(|e| VectorStoreError::Serialization(e.to_string()))?;
112 sqlx::query(
113 "INSERT INTO vector_points (id, collection, vector, payload) VALUES (?, ?, ?, ?) \
114 ON CONFLICT(collection, id) DO UPDATE SET vector = excluded.vector, payload = excluded.payload",
115 )
116 .bind(&point.id)
117 .bind(&collection)
118 .bind(&vector_bytes)
119 .bind(&payload_json)
120 .execute(&self.pool)
121 .await
122 .map_err(|e| VectorStoreError::Upsert(e.to_string()))?;
123 }
124 Ok(())
125 })
126 }
127
128 fn search(
129 &self,
130 collection: &str,
131 vector: Vec<f32>,
132 limit: u64,
133 filter: Option<VectorFilter>,
134 ) -> BoxFuture<'_, Result<Vec<ScoredVectorPoint>, VectorStoreError>> {
135 let collection = collection.to_owned();
136 Box::pin(async move {
137 let rows: Vec<(String, Vec<u8>, String)> = sqlx::query_as(
138 "SELECT id, vector, payload FROM vector_points WHERE collection = ?",
139 )
140 .bind(&collection)
141 .fetch_all(&self.pool)
142 .await
143 .map_err(|e| VectorStoreError::Search(e.to_string()))?;
144
145 let limit_usize = usize::try_from(limit).unwrap_or(usize::MAX);
146 let mut scored: Vec<ScoredVectorPoint> = rows
147 .into_iter()
148 .filter_map(|(id, blob, payload_str)| {
149 let Ok(stored) = bytemuck::try_cast_slice::<u8, f32>(&blob) else {
150 return None;
151 };
152 let payload: HashMap<String, serde_json::Value> =
153 serde_json::from_str(&payload_str).unwrap_or_default();
154
155 if filter
156 .as_ref()
157 .is_some_and(|f| !matches_filter(&payload, f))
158 {
159 return None;
160 }
161
162 let score = cosine_similarity(&vector, stored);
163 Some(ScoredVectorPoint { id, score, payload })
164 })
165 .collect();
166
167 scored.sort_by(|a, b| {
168 b.score
169 .partial_cmp(&a.score)
170 .unwrap_or(std::cmp::Ordering::Equal)
171 });
172 scored.truncate(limit_usize);
173 Ok(scored)
174 })
175 }
176
177 fn delete_by_ids(
178 &self,
179 collection: &str,
180 ids: Vec<String>,
181 ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
182 let collection = collection.to_owned();
183 Box::pin(async move {
184 for id in ids {
185 sqlx::query("DELETE FROM vector_points WHERE collection = ? AND id = ?")
186 .bind(&collection)
187 .bind(&id)
188 .execute(&self.pool)
189 .await
190 .map_err(|e| VectorStoreError::Delete(e.to_string()))?;
191 }
192 Ok(())
193 })
194 }
195
196 fn scroll_all(
197 &self,
198 collection: &str,
199 key_field: &str,
200 ) -> BoxFuture<'_, Result<ScrollResult, VectorStoreError>> {
201 let collection = collection.to_owned();
202 let key_field = key_field.to_owned();
203 Box::pin(async move {
204 let rows: Vec<(String, String)> =
205 sqlx::query_as("SELECT id, payload FROM vector_points WHERE collection = ?")
206 .bind(&collection)
207 .fetch_all(&self.pool)
208 .await
209 .map_err(|e| VectorStoreError::Scroll(e.to_string()))?;
210
211 let mut result = ScrollResult::new();
212 for (id, payload_str) in rows {
213 let payload: HashMap<String, serde_json::Value> =
214 serde_json::from_str(&payload_str).unwrap_or_default();
215 if let Some(val) = payload.get(&key_field) {
216 let mut map = HashMap::new();
217 map.insert(
218 key_field.clone(),
219 val.as_str().unwrap_or_default().to_owned(),
220 );
221 result.insert(id, map);
222 }
223 }
224 Ok(result)
225 })
226 }
227
228 fn health_check(&self) -> BoxFuture<'_, Result<bool, VectorStoreError>> {
229 Box::pin(async move {
230 sqlx::query_scalar::<_, i32>("SELECT 1")
231 .fetch_one(&self.pool)
232 .await
233 .map(|_| true)
234 .map_err(|e| VectorStoreError::Collection(e.to_string()))
235 })
236 }
237}
238
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sqlite::SqliteStore;
    use crate::vector_store::FieldCondition;

    /// Unit vector along the first axis; most tests store and query with it.
    const AXIS_0: [f32; 4] = [1.0, 0.0, 0.0, 0.0];
    /// Unit vector along the second axis, orthogonal to `AXIS_0`.
    const AXIS_1: [f32; 4] = [0.0, 1.0, 0.0, 0.0];

    /// Opens an in-memory SQLite store and a vector store sharing its pool.
    async fn setup() -> (SqliteVectorStore, SqliteStore) {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let vs = SqliteVectorStore::new(store.pool().clone());
        (vs, store)
    }

    /// Fresh vector store with the 4-dimensional collection `name` already registered.
    async fn store_with_collection(name: &str) -> SqliteVectorStore {
        let (vs, _) = setup().await;
        vs.ensure_collection(name, 4).await.unwrap();
        vs
    }

    /// Builds a point from an id, a 4-component vector and payload entries.
    fn point(id: &str, vector: [f32; 4], fields: &[(&str, serde_json::Value)]) -> VectorPoint {
        VectorPoint {
            id: id.into(),
            vector: vector.to_vec(),
            payload: fields
                .iter()
                .map(|(key, value)| ((*key).to_owned(), value.clone()))
                .collect(),
        }
    }

    /// Builds a single filter condition on `field`.
    fn condition(field: &str, value: FieldValue) -> FieldCondition {
        FieldCondition {
            field: field.into(),
            value,
        }
    }

    /// Searches collection `"c"` and unwraps the result.
    async fn search_c(
        vs: &SqliteVectorStore,
        query: [f32; 4],
        limit: u64,
        filter: Option<VectorFilter>,
    ) -> Vec<ScoredVectorPoint> {
        vs.search("c", query.to_vec(), limit, filter).await.unwrap()
    }

    /// Asserts that `results` holds exactly one point and that its id is `expected_id`.
    #[track_caller]
    fn assert_single_hit(results: &[ScoredVectorPoint], expected_id: &str) {
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, expected_id);
    }

    #[tokio::test]
    async fn ensure_and_exists() {
        let (vs, _) = setup().await;
        assert!(!vs.collection_exists("col1").await.unwrap());
        // Ensuring twice must leave the collection registered both times.
        for _ in 0..2 {
            vs.ensure_collection("col1", 4).await.unwrap();
            assert!(vs.collection_exists("col1").await.unwrap());
        }
    }

    #[tokio::test]
    async fn delete_collection() {
        let vs = store_with_collection("col1").await;
        vs.upsert("col1", vec![point("p1", AXIS_0, &[])])
            .await
            .unwrap();
        vs.delete_collection("col1").await.unwrap();
        assert!(!vs.collection_exists("col1").await.unwrap());
    }

    #[tokio::test]
    async fn upsert_and_search() {
        let vs = store_with_collection("c").await;
        let points = vec![
            point("a", AXIS_0, &[("role", serde_json::json!("user"))]),
            point("b", AXIS_1, &[("role", serde_json::json!("assistant"))]),
        ];
        vs.upsert("c", points).await.unwrap();

        let results = search_c(&vs, AXIS_0, 2, None).await;
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "a");
        assert!((results[0].score - 1.0).abs() < 1e-5);
    }

    #[tokio::test]
    async fn search_with_filter() {
        let vs = store_with_collection("c").await;
        let points = vec![
            point("a", AXIS_0, &[("role", serde_json::json!("user"))]),
            point("b", AXIS_0, &[("role", serde_json::json!("assistant"))]),
        ];
        vs.upsert("c", points).await.unwrap();

        let filter = VectorFilter {
            must: vec![condition("role", FieldValue::Text("user".into()))],
            must_not: vec![],
        };
        let results = search_c(&vs, AXIS_0, 10, Some(filter)).await;
        assert_single_hit(&results, "a");
    }

    #[tokio::test]
    async fn delete_by_ids() {
        let vs = store_with_collection("c").await;
        let points = vec![point("a", AXIS_0, &[]), point("b", AXIS_1, &[])];
        vs.upsert("c", points).await.unwrap();

        vs.delete_by_ids("c", vec!["a".into()]).await.unwrap();

        let results = search_c(&vs, AXIS_0, 10, None).await;
        assert_single_hit(&results, "b");
    }

    #[tokio::test]
    async fn scroll_all() {
        let vs = store_with_collection("c").await;
        let points = vec![point("p1", AXIS_0, &[("text", serde_json::json!("hello"))])];
        vs.upsert("c", points).await.unwrap();

        let result = vs.scroll_all("c", "text").await.unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result["p1"]["text"], "hello");
    }

    #[tokio::test]
    async fn upsert_updates_existing() {
        let vs = store_with_collection("c").await;
        let first = vec![point("p1", AXIS_0, &[("v", serde_json::json!(1))])];
        vs.upsert("c", first).await.unwrap();
        let second = vec![point("p1", AXIS_1, &[("v", serde_json::json!(2))])];
        vs.upsert("c", second).await.unwrap();

        // Only the second vector can score 1.0 against the second-axis query.
        let results = search_c(&vs, AXIS_1, 1, None).await;
        assert_eq!(results.len(), 1);
        assert!((results[0].score - 1.0).abs() < 1e-5);
    }

    #[test]
    fn cosine_similarity_import_wired() {
        assert!(!cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]).is_nan());
    }

    #[tokio::test]
    async fn search_with_must_not_filter() {
        let vs = store_with_collection("c").await;
        let points = vec![
            point("a", AXIS_0, &[("role", serde_json::json!("user"))]),
            point("b", AXIS_0, &[("role", serde_json::json!("system"))]),
        ];
        vs.upsert("c", points).await.unwrap();

        let filter = VectorFilter {
            must: vec![],
            must_not: vec![condition("role", FieldValue::Text("system".into()))],
        };
        let results = search_c(&vs, AXIS_0, 10, Some(filter)).await;
        assert_single_hit(&results, "a");
    }

    #[tokio::test]
    async fn search_with_integer_filter() {
        let vs = store_with_collection("c").await;
        let points = vec![
            point("a", AXIS_0, &[("conv_id", serde_json::json!(1))]),
            point("b", AXIS_0, &[("conv_id", serde_json::json!(2))]),
        ];
        vs.upsert("c", points).await.unwrap();

        let filter = VectorFilter {
            must: vec![condition("conv_id", FieldValue::Integer(1))],
            must_not: vec![],
        };
        let results = search_c(&vs, AXIS_0, 10, Some(filter)).await;
        assert_single_hit(&results, "a");
    }

    #[tokio::test]
    async fn search_empty_collection() {
        let vs = store_with_collection("c").await;
        let results = search_c(&vs, AXIS_0, 10, None).await;
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn search_with_must_not_integer_filter() {
        let vs = store_with_collection("c").await;
        let points = vec![
            point("a", AXIS_0, &[("conv_id", serde_json::json!(1))]),
            point("b", AXIS_0, &[("conv_id", serde_json::json!(2))]),
        ];
        vs.upsert("c", points).await.unwrap();

        let filter = VectorFilter {
            must: vec![],
            must_not: vec![condition("conv_id", FieldValue::Integer(1))],
        };
        let results = search_c(&vs, AXIS_0, 10, Some(filter)).await;
        assert_single_hit(&results, "b");
    }

    #[tokio::test]
    async fn search_with_combined_must_and_must_not() {
        let vs = store_with_collection("c").await;
        let points = vec![
            point(
                "a",
                AXIS_0,
                &[
                    ("role", serde_json::json!("user")),
                    ("conv_id", serde_json::json!(1)),
                ],
            ),
            point(
                "b",
                AXIS_0,
                &[
                    ("role", serde_json::json!("user")),
                    ("conv_id", serde_json::json!(2)),
                ],
            ),
            point(
                "c",
                AXIS_0,
                &[
                    ("role", serde_json::json!("assistant")),
                    ("conv_id", serde_json::json!(1)),
                ],
            ),
        ];
        vs.upsert("c", points).await.unwrap();

        let filter = VectorFilter {
            must: vec![condition("role", FieldValue::Text("user".into()))],
            must_not: vec![condition("conv_id", FieldValue::Integer(2))],
        };
        let results = search_c(&vs, AXIS_0, 10, Some(filter)).await;
        assert_single_hit(&results, "a");
    }

    #[tokio::test]
    async fn scroll_all_missing_key_field() {
        let vs = store_with_collection("c").await;
        let points = vec![point("p1", AXIS_0, &[("other", serde_json::json!("value"))])];
        vs.upsert("c", points).await.unwrap();

        let result = vs.scroll_all("c", "text").await.unwrap();
        assert!(
            result.is_empty(),
            "points without the key field must not appear in scroll result"
        );
    }

    #[tokio::test]
    async fn delete_by_ids_empty_and_nonexistent() {
        let vs = store_with_collection("c").await;
        vs.upsert("c", vec![point("a", AXIS_0, &[])]).await.unwrap();

        // Neither an empty id list nor an unknown id may fail or remove anything.
        vs.delete_by_ids("c", Vec::new()).await.unwrap();
        vs.delete_by_ids("c", vec!["nonexistent".into()])
            .await
            .unwrap();

        let results = search_c(&vs, AXIS_0, 10, None).await;
        assert_single_hit(&results, "a");
    }

    #[tokio::test]
    async fn search_corrupt_blob_skipped() {
        let (vs, store) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        vs.upsert("c", vec![point("valid", AXIS_0, &[])])
            .await
            .unwrap();

        // Three bytes cannot be read as a sequence of f32 values.
        let corrupt_blob: Vec<u8> = vec![0xFF, 0xFE, 0xFD];
        let payload_json = "{}";
        sqlx::query(
            "INSERT INTO vector_points (id, collection, vector, payload) VALUES (?, ?, ?, ?)",
        )
        .bind("corrupt")
        .bind("c")
        .bind(&corrupt_blob)
        .bind(payload_json)
        .execute(store.pool())
        .await
        .unwrap();

        let results = search_c(&vs, AXIS_0, 10, None).await;
        assert_single_hit(&results, "valid");
    }
}