1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4
5use sqlx::SqlitePool;
6
7use crate::vector_store::{
8 FieldValue, ScoredVectorPoint, ScrollResult, VectorFilter, VectorPoint, VectorStore,
9 VectorStoreError,
10};
11
12pub struct SqliteVectorStore {
13 pool: SqlitePool,
14}
15
16impl SqliteVectorStore {
17 #[must_use]
18 pub fn new(pool: SqlitePool) -> Self {
19 Self { pool }
20 }
21}
22
/// Cosine similarity of `a` and `b`, in `[-1.0, 1.0]` for finite inputs.
///
/// Yields `0.0` when the slices are empty, have different lengths, or either one has zero
/// magnitude, so the division below never divides by zero.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let comparable = !a.is_empty() && a.len() == b.len();
    if !comparable {
        return 0.0;
    }

    let magnitude = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let (len_a, len_b) = (magnitude(a), magnitude(b));
    if len_a == 0.0 || len_b == 0.0 {
        return 0.0;
    }

    let dot = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
    dot / (len_a * len_b)
}
35
36fn matches_filter(payload: &HashMap<String, serde_json::Value>, filter: &VectorFilter) -> bool {
37 for cond in &filter.must {
38 let Some(val) = payload.get(&cond.field) else {
39 return false;
40 };
41 let matches = match &cond.value {
42 FieldValue::Integer(i) => val.as_i64().is_some_and(|v| v == *i),
43 FieldValue::Text(t) => val.as_str().is_some_and(|v| v == t.as_str()),
44 };
45 if !matches {
46 return false;
47 }
48 }
49 for cond in &filter.must_not {
50 let Some(val) = payload.get(&cond.field) else {
51 continue;
52 };
53 let matches = match &cond.value {
54 FieldValue::Integer(i) => val.as_i64().is_some_and(|v| v == *i),
55 FieldValue::Text(t) => val.as_str().is_some_and(|v| v == t.as_str()),
56 };
57 if matches {
58 return false;
59 }
60 }
61 true
62}
63
/// Boxed, pinned future returned by every `VectorStore` method implemented below.
///
/// The future is `Send` and may borrow for `'a`; the methods below instantiate `'a` as the
/// borrow of `&self` (`'_`), since their async blocks use `self.pool`.
type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
65
66impl VectorStore for SqliteVectorStore {
67 fn ensure_collection(
68 &self,
69 collection: &str,
70 _vector_size: u64,
71 ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
72 let collection = collection.to_owned();
73 Box::pin(async move {
74 sqlx::query("INSERT OR IGNORE INTO vector_collections (name) VALUES (?)")
75 .bind(&collection)
76 .execute(&self.pool)
77 .await
78 .map_err(|e| VectorStoreError::Collection(e.to_string()))?;
79 Ok(())
80 })
81 }
82
83 fn collection_exists(&self, collection: &str) -> BoxFuture<'_, Result<bool, VectorStoreError>> {
84 let collection = collection.to_owned();
85 Box::pin(async move {
86 let row: (i64,) =
87 sqlx::query_as("SELECT COUNT(*) FROM vector_collections WHERE name = ?")
88 .bind(&collection)
89 .fetch_one(&self.pool)
90 .await
91 .map_err(|e| VectorStoreError::Connection(e.to_string()))?;
92 Ok(row.0 > 0)
93 })
94 }
95
96 fn delete_collection(&self, collection: &str) -> BoxFuture<'_, Result<(), VectorStoreError>> {
97 let collection = collection.to_owned();
98 Box::pin(async move {
99 sqlx::query("DELETE FROM vector_points WHERE collection = ?")
100 .bind(&collection)
101 .execute(&self.pool)
102 .await
103 .map_err(|e| VectorStoreError::Delete(e.to_string()))?;
104 sqlx::query("DELETE FROM vector_collections WHERE name = ?")
105 .bind(&collection)
106 .execute(&self.pool)
107 .await
108 .map_err(|e| VectorStoreError::Delete(e.to_string()))?;
109 Ok(())
110 })
111 }
112
113 fn upsert(
114 &self,
115 collection: &str,
116 points: Vec<VectorPoint>,
117 ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
118 let collection = collection.to_owned();
119 Box::pin(async move {
120 for point in points {
121 let vector_bytes: Vec<u8> = bytemuck::cast_slice(&point.vector).to_vec();
122 let payload_json = serde_json::to_string(&point.payload)
123 .map_err(|e| VectorStoreError::Serialization(e.to_string()))?;
124 sqlx::query(
125 "INSERT INTO vector_points (id, collection, vector, payload) VALUES (?, ?, ?, ?) \
126 ON CONFLICT(collection, id) DO UPDATE SET vector = excluded.vector, payload = excluded.payload",
127 )
128 .bind(&point.id)
129 .bind(&collection)
130 .bind(&vector_bytes)
131 .bind(&payload_json)
132 .execute(&self.pool)
133 .await
134 .map_err(|e| VectorStoreError::Upsert(e.to_string()))?;
135 }
136 Ok(())
137 })
138 }
139
140 fn search(
141 &self,
142 collection: &str,
143 vector: Vec<f32>,
144 limit: u64,
145 filter: Option<VectorFilter>,
146 ) -> BoxFuture<'_, Result<Vec<ScoredVectorPoint>, VectorStoreError>> {
147 let collection = collection.to_owned();
148 Box::pin(async move {
149 let rows: Vec<(String, Vec<u8>, String)> = sqlx::query_as(
150 "SELECT id, vector, payload FROM vector_points WHERE collection = ?",
151 )
152 .bind(&collection)
153 .fetch_all(&self.pool)
154 .await
155 .map_err(|e| VectorStoreError::Search(e.to_string()))?;
156
157 let limit_usize = usize::try_from(limit).unwrap_or(usize::MAX);
158 let mut scored: Vec<ScoredVectorPoint> = rows
159 .into_iter()
160 .filter_map(|(id, blob, payload_str)| {
161 let Ok(stored) = bytemuck::try_cast_slice::<u8, f32>(&blob) else {
162 return None;
163 };
164 let payload: HashMap<String, serde_json::Value> =
165 serde_json::from_str(&payload_str).unwrap_or_default();
166
167 if filter
168 .as_ref()
169 .is_some_and(|f| !matches_filter(&payload, f))
170 {
171 return None;
172 }
173
174 let score = cosine_similarity(&vector, stored);
175 Some(ScoredVectorPoint { id, score, payload })
176 })
177 .collect();
178
179 scored.sort_by(|a, b| {
180 b.score
181 .partial_cmp(&a.score)
182 .unwrap_or(std::cmp::Ordering::Equal)
183 });
184 scored.truncate(limit_usize);
185 Ok(scored)
186 })
187 }
188
189 fn delete_by_ids(
190 &self,
191 collection: &str,
192 ids: Vec<String>,
193 ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
194 let collection = collection.to_owned();
195 Box::pin(async move {
196 for id in ids {
197 sqlx::query("DELETE FROM vector_points WHERE collection = ? AND id = ?")
198 .bind(&collection)
199 .bind(&id)
200 .execute(&self.pool)
201 .await
202 .map_err(|e| VectorStoreError::Delete(e.to_string()))?;
203 }
204 Ok(())
205 })
206 }
207
208 fn scroll_all(
209 &self,
210 collection: &str,
211 key_field: &str,
212 ) -> BoxFuture<'_, Result<ScrollResult, VectorStoreError>> {
213 let collection = collection.to_owned();
214 let key_field = key_field.to_owned();
215 Box::pin(async move {
216 let rows: Vec<(String, String)> =
217 sqlx::query_as("SELECT id, payload FROM vector_points WHERE collection = ?")
218 .bind(&collection)
219 .fetch_all(&self.pool)
220 .await
221 .map_err(|e| VectorStoreError::Scroll(e.to_string()))?;
222
223 let mut result = ScrollResult::new();
224 for (id, payload_str) in rows {
225 let payload: HashMap<String, serde_json::Value> =
226 serde_json::from_str(&payload_str).unwrap_or_default();
227 if let Some(val) = payload.get(&key_field) {
228 let mut map = HashMap::new();
229 map.insert(
230 key_field.clone(),
231 val.as_str().unwrap_or_default().to_owned(),
232 );
233 result.insert(id, map);
234 }
235 }
236 Ok(result)
237 })
238 }
239
240 fn health_check(&self) -> BoxFuture<'_, Result<bool, VectorStoreError>> {
241 Box::pin(async move {
242 sqlx::query_scalar::<_, i32>("SELECT 1")
243 .fetch_one(&self.pool)
244 .await
245 .map(|_| true)
246 .map_err(|e| VectorStoreError::Collection(e.to_string()))
247 })
248 }
249}
250
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sqlite::SqliteStore;
    use crate::vector_store::FieldCondition;

    /// Creates a vector store over a fresh in-memory `SqliteStore`.
    ///
    /// The `SqliteStore` is returned as well so tests can write rows directly through its
    /// pool (see `search_corrupt_blob_skipped`).
    async fn setup() -> (SqliteVectorStore, SqliteStore) {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let pool = store.pool().clone();
        let vs = SqliteVectorStore::new(pool);
        (vs, store)
    }

    #[tokio::test]
    async fn ensure_and_exists() {
        let (vs, _) = setup().await;
        assert!(!vs.collection_exists("col1").await.unwrap());
        vs.ensure_collection("col1", 4).await.unwrap();
        assert!(vs.collection_exists("col1").await.unwrap());
        // A second call must succeed and leave the collection in place.
        vs.ensure_collection("col1", 4).await.unwrap();
        assert!(vs.collection_exists("col1").await.unwrap());
    }

    #[tokio::test]
    async fn delete_collection() {
        let (vs, _) = setup().await;
        vs.ensure_collection("col1", 4).await.unwrap();
        vs.upsert(
            "col1",
            vec![VectorPoint {
                id: "p1".into(),
                vector: vec![1.0, 0.0, 0.0, 0.0],
                payload: HashMap::new(),
            }],
        )
        .await
        .unwrap();
        vs.delete_collection("col1").await.unwrap();
        assert!(!vs.collection_exists("col1").await.unwrap());
    }

    #[tokio::test]
    async fn delete_collection_removes_points() {
        let (vs, _) = setup().await;
        vs.ensure_collection("col1", 4).await.unwrap();
        vs.upsert(
            "col1",
            vec![VectorPoint {
                id: "p1".into(),
                vector: vec![1.0, 0.0, 0.0, 0.0],
                payload: HashMap::new(),
            }],
        )
        .await
        .unwrap();
        vs.delete_collection("col1").await.unwrap();

        // Re-creating the collection must not resurrect the deleted points.
        vs.ensure_collection("col1", 4).await.unwrap();
        let results = vs
            .search("col1", vec![1.0, 0.0, 0.0, 0.0], 10, None)
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn upsert_and_search() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        vs.upsert(
            "c",
            vec![
                VectorPoint {
                    id: "a".into(),
                    vector: vec![1.0, 0.0, 0.0, 0.0],
                    payload: HashMap::from([("role".into(), serde_json::json!("user"))]),
                },
                VectorPoint {
                    id: "b".into(),
                    vector: vec![0.0, 1.0, 0.0, 0.0],
                    payload: HashMap::from([("role".into(), serde_json::json!("assistant"))]),
                },
            ],
        )
        .await
        .unwrap();

        let results = vs
            .search("c", vec![1.0, 0.0, 0.0, 0.0], 2, None)
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "a");
        assert!((results[0].score - 1.0).abs() < 1e-5);
    }

    #[tokio::test]
    async fn upsert_empty_batch_is_noop() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        vs.upsert("c", vec![]).await.unwrap();
        let results = vs
            .search("c", vec![1.0, 0.0, 0.0, 0.0], 10, None)
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn search_orders_by_score_and_applies_limit() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        vs.upsert(
            "c",
            vec![
                VectorPoint {
                    id: "far".into(),
                    vector: vec![0.0, 0.0, 1.0, 0.0],
                    payload: HashMap::new(),
                },
                VectorPoint {
                    id: "close".into(),
                    vector: vec![1.0, 1.0, 0.0, 0.0],
                    payload: HashMap::new(),
                },
                VectorPoint {
                    id: "exact".into(),
                    vector: vec![1.0, 0.0, 0.0, 0.0],
                    payload: HashMap::new(),
                },
            ],
        )
        .await
        .unwrap();

        let results = vs
            .search("c", vec![1.0, 0.0, 0.0, 0.0], 2, None)
            .await
            .unwrap();
        let ids: Vec<&str> = results.iter().map(|p| p.id.as_str()).collect();
        assert_eq!(ids, ["exact", "close"]);
        assert!(results[0].score >= results[1].score);
    }

    #[tokio::test]
    async fn search_is_scoped_to_collection() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c1", 4).await.unwrap();
        vs.ensure_collection("c2", 4).await.unwrap();
        vs.upsert(
            "c1",
            vec![VectorPoint {
                id: "a".into(),
                vector: vec![1.0, 0.0, 0.0, 0.0],
                payload: HashMap::new(),
            }],
        )
        .await
        .unwrap();

        let other = vs
            .search("c2", vec![1.0, 0.0, 0.0, 0.0], 10, None)
            .await
            .unwrap();
        assert!(other.is_empty());

        let own = vs
            .search("c1", vec![1.0, 0.0, 0.0, 0.0], 10, None)
            .await
            .unwrap();
        assert_eq!(own.len(), 1);
        assert_eq!(own[0].id, "a");
    }

    #[tokio::test]
    async fn search_dimension_mismatch_scores_zero() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        vs.upsert(
            "c",
            vec![VectorPoint {
                id: "short".into(),
                vector: vec![1.0, 0.0],
                payload: HashMap::new(),
            }],
        )
        .await
        .unwrap();

        // The point is still returned, but cosine similarity of mismatched lengths is 0.0.
        let results = vs
            .search("c", vec![1.0, 0.0, 0.0, 0.0], 10, None)
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "short");
        assert_eq!(results[0].score, 0.0);
    }

    #[tokio::test]
    async fn search_with_filter() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        vs.upsert(
            "c",
            vec![
                VectorPoint {
                    id: "a".into(),
                    vector: vec![1.0, 0.0, 0.0, 0.0],
                    payload: HashMap::from([("role".into(), serde_json::json!("user"))]),
                },
                VectorPoint {
                    id: "b".into(),
                    vector: vec![1.0, 0.0, 0.0, 0.0],
                    payload: HashMap::from([("role".into(), serde_json::json!("assistant"))]),
                },
            ],
        )
        .await
        .unwrap();

        let filter = VectorFilter {
            must: vec![FieldCondition {
                field: "role".into(),
                value: FieldValue::Text("user".into()),
            }],
            must_not: vec![],
        };
        let results = vs
            .search("c", vec![1.0, 0.0, 0.0, 0.0], 10, Some(filter))
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "a");
    }

    #[tokio::test]
    async fn delete_by_ids() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        vs.upsert(
            "c",
            vec![
                VectorPoint {
                    id: "a".into(),
                    vector: vec![1.0, 0.0, 0.0, 0.0],
                    payload: HashMap::new(),
                },
                VectorPoint {
                    id: "b".into(),
                    vector: vec![0.0, 1.0, 0.0, 0.0],
                    payload: HashMap::new(),
                },
            ],
        )
        .await
        .unwrap();
        vs.delete_by_ids("c", vec!["a".into()]).await.unwrap();
        let results = vs
            .search("c", vec![1.0, 0.0, 0.0, 0.0], 10, None)
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "b");
    }

    #[tokio::test]
    async fn scroll_all() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        vs.upsert(
            "c",
            vec![VectorPoint {
                id: "p1".into(),
                vector: vec![1.0, 0.0, 0.0, 0.0],
                payload: HashMap::from([("text".into(), serde_json::json!("hello"))]),
            }],
        )
        .await
        .unwrap();
        let result = vs.scroll_all("c", "text").await.unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result["p1"]["text"], "hello");
    }

    #[tokio::test]
    async fn upsert_updates_existing() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        vs.upsert(
            "c",
            vec![VectorPoint {
                id: "p1".into(),
                vector: vec![1.0, 0.0, 0.0, 0.0],
                payload: HashMap::from([("v".into(), serde_json::json!(1))]),
            }],
        )
        .await
        .unwrap();
        vs.upsert(
            "c",
            vec![VectorPoint {
                id: "p1".into(),
                vector: vec![0.0, 1.0, 0.0, 0.0],
                payload: HashMap::from([("v".into(), serde_json::json!(2))]),
            }],
        )
        .await
        .unwrap();
        // Only one row for "p1" may remain, and it must carry the second vector.
        let results = vs
            .search("c", vec![0.0, 1.0, 0.0, 0.0], 1, None)
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert!((results[0].score - 1.0).abs() < 1e-5);
    }

    #[test]
    fn cosine_similarity_orthogonal() {
        assert_eq!(cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]), 0.0);
    }

    #[test]
    fn cosine_similarity_identical() {
        let v = vec![3.0, 4.0];
        assert!((cosine_similarity(&v, &v) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_zero_vector() {
        assert_eq!(cosine_similarity(&[0.0, 0.0], &[1.0, 0.0]), 0.0);
    }

    #[test]
    fn cosine_similarity_length_mismatch() {
        assert_eq!(cosine_similarity(&[1.0], &[1.0, 0.0]), 0.0);
    }

    #[tokio::test]
    async fn search_with_must_not_filter() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        vs.upsert(
            "c",
            vec![
                VectorPoint {
                    id: "a".into(),
                    vector: vec![1.0, 0.0, 0.0, 0.0],
                    payload: HashMap::from([("role".into(), serde_json::json!("user"))]),
                },
                VectorPoint {
                    id: "b".into(),
                    vector: vec![1.0, 0.0, 0.0, 0.0],
                    payload: HashMap::from([("role".into(), serde_json::json!("system"))]),
                },
            ],
        )
        .await
        .unwrap();

        let filter = VectorFilter {
            must: vec![],
            must_not: vec![FieldCondition {
                field: "role".into(),
                value: FieldValue::Text("system".into()),
            }],
        };
        let results = vs
            .search("c", vec![1.0, 0.0, 0.0, 0.0], 10, Some(filter))
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "a");
    }

    #[tokio::test]
    async fn search_with_integer_filter() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        vs.upsert(
            "c",
            vec![
                VectorPoint {
                    id: "a".into(),
                    vector: vec![1.0, 0.0, 0.0, 0.0],
                    payload: HashMap::from([("conv_id".into(), serde_json::json!(1))]),
                },
                VectorPoint {
                    id: "b".into(),
                    vector: vec![1.0, 0.0, 0.0, 0.0],
                    payload: HashMap::from([("conv_id".into(), serde_json::json!(2))]),
                },
            ],
        )
        .await
        .unwrap();

        let filter = VectorFilter {
            must: vec![FieldCondition {
                field: "conv_id".into(),
                value: FieldValue::Integer(1),
            }],
            must_not: vec![],
        };
        let results = vs
            .search("c", vec![1.0, 0.0, 0.0, 0.0], 10, Some(filter))
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "a");
    }

    #[tokio::test]
    async fn search_empty_collection() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        let results = vs
            .search("c", vec![1.0, 0.0, 0.0, 0.0], 10, None)
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn search_with_must_not_integer_filter() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        vs.upsert(
            "c",
            vec![
                VectorPoint {
                    id: "a".into(),
                    vector: vec![1.0, 0.0, 0.0, 0.0],
                    payload: HashMap::from([("conv_id".into(), serde_json::json!(1))]),
                },
                VectorPoint {
                    id: "b".into(),
                    vector: vec![1.0, 0.0, 0.0, 0.0],
                    payload: HashMap::from([("conv_id".into(), serde_json::json!(2))]),
                },
            ],
        )
        .await
        .unwrap();

        let filter = VectorFilter {
            must: vec![],
            must_not: vec![FieldCondition {
                field: "conv_id".into(),
                value: FieldValue::Integer(1),
            }],
        };
        let results = vs
            .search("c", vec![1.0, 0.0, 0.0, 0.0], 10, Some(filter))
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "b");
    }

    #[tokio::test]
    async fn search_with_combined_must_and_must_not() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        vs.upsert(
            "c",
            vec![
                VectorPoint {
                    id: "a".into(),
                    vector: vec![1.0, 0.0, 0.0, 0.0],
                    payload: HashMap::from([
                        ("role".into(), serde_json::json!("user")),
                        ("conv_id".into(), serde_json::json!(1)),
                    ]),
                },
                VectorPoint {
                    id: "b".into(),
                    vector: vec![1.0, 0.0, 0.0, 0.0],
                    payload: HashMap::from([
                        ("role".into(), serde_json::json!("user")),
                        ("conv_id".into(), serde_json::json!(2)),
                    ]),
                },
                VectorPoint {
                    id: "c".into(),
                    vector: vec![1.0, 0.0, 0.0, 0.0],
                    payload: HashMap::from([
                        ("role".into(), serde_json::json!("assistant")),
                        ("conv_id".into(), serde_json::json!(1)),
                    ]),
                },
            ],
        )
        .await
        .unwrap();

        let filter = VectorFilter {
            must: vec![FieldCondition {
                field: "role".into(),
                value: FieldValue::Text("user".into()),
            }],
            must_not: vec![FieldCondition {
                field: "conv_id".into(),
                value: FieldValue::Integer(2),
            }],
        };
        let results = vs
            .search("c", vec![1.0, 0.0, 0.0, 0.0], 10, Some(filter))
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        // "b" fails must_not (conv_id 2), "c" fails must (role assistant).
        assert_eq!(results[0].id, "a");
    }

    #[tokio::test]
    async fn scroll_all_missing_key_field() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        vs.upsert(
            "c",
            vec![VectorPoint {
                id: "p1".into(),
                vector: vec![1.0, 0.0, 0.0, 0.0],
                payload: HashMap::from([("other".into(), serde_json::json!("value"))]),
            }],
        )
        .await
        .unwrap();
        let result = vs.scroll_all("c", "text").await.unwrap();
        assert!(
            result.is_empty(),
            "points without the key field must not appear in scroll result"
        );
    }

    #[tokio::test]
    async fn delete_by_ids_empty_and_nonexistent() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        vs.upsert(
            "c",
            vec![VectorPoint {
                id: "a".into(),
                vector: vec![1.0, 0.0, 0.0, 0.0],
                payload: HashMap::new(),
            }],
        )
        .await
        .unwrap();

        // An empty id list is accepted and deletes nothing.
        vs.delete_by_ids("c", vec![]).await.unwrap();

        // Unknown ids are not an error.
        vs.delete_by_ids("c", vec!["nonexistent".into()])
            .await
            .unwrap();

        // The existing point must survive both calls.
        let results = vs
            .search("c", vec![1.0, 0.0, 0.0, 0.0], 10, None)
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "a");
    }

    #[tokio::test]
    async fn search_corrupt_blob_skipped() {
        let (vs, store) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();

        // One well-formed point written through the normal API.
        vs.upsert(
            "c",
            vec![VectorPoint {
                id: "valid".into(),
                vector: vec![1.0, 0.0, 0.0, 0.0],
                payload: HashMap::new(),
            }],
        )
        .await
        .unwrap();

        // A 3-byte BLOB cannot hold whole f32 values; insert it directly, bypassing upsert.
        let corrupt_blob: Vec<u8> = vec![0xFF, 0xFE, 0xFD];
        let payload_json = r#"{}"#;
        sqlx::query(
            "INSERT INTO vector_points (id, collection, vector, payload) VALUES (?, ?, ?, ?)",
        )
        .bind("corrupt")
        .bind("c")
        .bind(&corrupt_blob)
        .bind(payload_json)
        .execute(store.pool())
        .await
        .unwrap();

        // The corrupt row is skipped instead of failing the whole search.
        let results = vs
            .search("c", vec![1.0, 0.0, 0.0, 0.0], 10, None)
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "valid");
    }
}