1use std::collections::HashMap;
5use std::future::Future;
6use std::pin::Pin;
7
8use sqlx::SqlitePool;
9
10use crate::vector_store::{
11 FieldValue, ScoredVectorPoint, ScrollResult, VectorFilter, VectorPoint, VectorStore,
12 VectorStoreError,
13};
14
15pub struct SqliteVectorStore {
16 pool: SqlitePool,
17}
18
19impl SqliteVectorStore {
20 #[must_use]
21 pub fn new(pool: SqlitePool) -> Self {
22 Self { pool }
23 }
24}
25
/// Cosine similarity of two equally sized vectors.
///
/// Returns `0.0` when the slices are empty, differ in length, or when either
/// vector has zero magnitude, so callers never have to special-case those inputs.
///
/// The dot product and both squared norms are accumulated in `f64` in a single
/// pass. Squaring an `f32` can overflow to infinity (yielding `NaN`) or underflow to
/// zero (yielding a spurious `0.0`); the wider accumulator avoids both and also
/// reduces rounding error on long vectors.
///
/// `NaN` or infinite components are not sanitised and propagate into the result.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    let (dot, sq_norm_a, sq_norm_b) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, sq_a, sq_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, sq_a + x * x, sq_b + y * y)
        },
    );

    if sq_norm_a == 0.0 || sq_norm_b == 0.0 {
        return 0.0;
    }

    // Take the square roots separately: multiplying the squared norms first could
    // overflow even in `f64` for extreme inputs.
    let similarity = dot / (sq_norm_a.sqrt() * sq_norm_b.sqrt());

    // Narrowing is intentional: the value lies in [-1, 1] up to rounding.
    #[allow(clippy::cast_possible_truncation)]
    let narrowed = similarity as f32;
    narrowed
}
38
39fn matches_filter(payload: &HashMap<String, serde_json::Value>, filter: &VectorFilter) -> bool {
40 for cond in &filter.must {
41 let Some(val) = payload.get(&cond.field) else {
42 return false;
43 };
44 let matches = match &cond.value {
45 FieldValue::Integer(i) => val.as_i64().is_some_and(|v| v == *i),
46 FieldValue::Text(t) => val.as_str().is_some_and(|v| v == t.as_str()),
47 };
48 if !matches {
49 return false;
50 }
51 }
52 for cond in &filter.must_not {
53 let Some(val) = payload.get(&cond.field) else {
54 continue;
55 };
56 let matches = match &cond.value {
57 FieldValue::Integer(i) => val.as_i64().is_some_and(|v| v == *i),
58 FieldValue::Text(t) => val.as_str().is_some_and(|v| v == t.as_str()),
59 };
60 if matches {
61 return false;
62 }
63 }
64 true
65}
66
/// Heap-allocated, pinned, `Send` future that may borrow for `'a`; the return type
/// of every [`VectorStore`] method implemented in this file.
type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
68
69impl VectorStore for SqliteVectorStore {
70 fn ensure_collection(
71 &self,
72 collection: &str,
73 _vector_size: u64,
74 ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
75 let collection = collection.to_owned();
76 Box::pin(async move {
77 sqlx::query("INSERT OR IGNORE INTO vector_collections (name) VALUES (?)")
78 .bind(&collection)
79 .execute(&self.pool)
80 .await
81 .map_err(|e| VectorStoreError::Collection(e.to_string()))?;
82 Ok(())
83 })
84 }
85
86 fn collection_exists(&self, collection: &str) -> BoxFuture<'_, Result<bool, VectorStoreError>> {
87 let collection = collection.to_owned();
88 Box::pin(async move {
89 let row: (i64,) =
90 sqlx::query_as("SELECT COUNT(*) FROM vector_collections WHERE name = ?")
91 .bind(&collection)
92 .fetch_one(&self.pool)
93 .await
94 .map_err(|e| VectorStoreError::Connection(e.to_string()))?;
95 Ok(row.0 > 0)
96 })
97 }
98
99 fn delete_collection(&self, collection: &str) -> BoxFuture<'_, Result<(), VectorStoreError>> {
100 let collection = collection.to_owned();
101 Box::pin(async move {
102 sqlx::query("DELETE FROM vector_points WHERE collection = ?")
103 .bind(&collection)
104 .execute(&self.pool)
105 .await
106 .map_err(|e| VectorStoreError::Delete(e.to_string()))?;
107 sqlx::query("DELETE FROM vector_collections WHERE name = ?")
108 .bind(&collection)
109 .execute(&self.pool)
110 .await
111 .map_err(|e| VectorStoreError::Delete(e.to_string()))?;
112 Ok(())
113 })
114 }
115
116 fn upsert(
117 &self,
118 collection: &str,
119 points: Vec<VectorPoint>,
120 ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
121 let collection = collection.to_owned();
122 Box::pin(async move {
123 for point in points {
124 let vector_bytes: Vec<u8> = bytemuck::cast_slice(&point.vector).to_vec();
125 let payload_json = serde_json::to_string(&point.payload)
126 .map_err(|e| VectorStoreError::Serialization(e.to_string()))?;
127 sqlx::query(
128 "INSERT INTO vector_points (id, collection, vector, payload) VALUES (?, ?, ?, ?) \
129 ON CONFLICT(collection, id) DO UPDATE SET vector = excluded.vector, payload = excluded.payload",
130 )
131 .bind(&point.id)
132 .bind(&collection)
133 .bind(&vector_bytes)
134 .bind(&payload_json)
135 .execute(&self.pool)
136 .await
137 .map_err(|e| VectorStoreError::Upsert(e.to_string()))?;
138 }
139 Ok(())
140 })
141 }
142
143 fn search(
144 &self,
145 collection: &str,
146 vector: Vec<f32>,
147 limit: u64,
148 filter: Option<VectorFilter>,
149 ) -> BoxFuture<'_, Result<Vec<ScoredVectorPoint>, VectorStoreError>> {
150 let collection = collection.to_owned();
151 Box::pin(async move {
152 let rows: Vec<(String, Vec<u8>, String)> = sqlx::query_as(
153 "SELECT id, vector, payload FROM vector_points WHERE collection = ?",
154 )
155 .bind(&collection)
156 .fetch_all(&self.pool)
157 .await
158 .map_err(|e| VectorStoreError::Search(e.to_string()))?;
159
160 let limit_usize = usize::try_from(limit).unwrap_or(usize::MAX);
161 let mut scored: Vec<ScoredVectorPoint> = rows
162 .into_iter()
163 .filter_map(|(id, blob, payload_str)| {
164 let Ok(stored) = bytemuck::try_cast_slice::<u8, f32>(&blob) else {
165 return None;
166 };
167 let payload: HashMap<String, serde_json::Value> =
168 serde_json::from_str(&payload_str).unwrap_or_default();
169
170 if filter
171 .as_ref()
172 .is_some_and(|f| !matches_filter(&payload, f))
173 {
174 return None;
175 }
176
177 let score = cosine_similarity(&vector, stored);
178 Some(ScoredVectorPoint { id, score, payload })
179 })
180 .collect();
181
182 scored.sort_by(|a, b| {
183 b.score
184 .partial_cmp(&a.score)
185 .unwrap_or(std::cmp::Ordering::Equal)
186 });
187 scored.truncate(limit_usize);
188 Ok(scored)
189 })
190 }
191
192 fn delete_by_ids(
193 &self,
194 collection: &str,
195 ids: Vec<String>,
196 ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
197 let collection = collection.to_owned();
198 Box::pin(async move {
199 for id in ids {
200 sqlx::query("DELETE FROM vector_points WHERE collection = ? AND id = ?")
201 .bind(&collection)
202 .bind(&id)
203 .execute(&self.pool)
204 .await
205 .map_err(|e| VectorStoreError::Delete(e.to_string()))?;
206 }
207 Ok(())
208 })
209 }
210
211 fn scroll_all(
212 &self,
213 collection: &str,
214 key_field: &str,
215 ) -> BoxFuture<'_, Result<ScrollResult, VectorStoreError>> {
216 let collection = collection.to_owned();
217 let key_field = key_field.to_owned();
218 Box::pin(async move {
219 let rows: Vec<(String, String)> =
220 sqlx::query_as("SELECT id, payload FROM vector_points WHERE collection = ?")
221 .bind(&collection)
222 .fetch_all(&self.pool)
223 .await
224 .map_err(|e| VectorStoreError::Scroll(e.to_string()))?;
225
226 let mut result = ScrollResult::new();
227 for (id, payload_str) in rows {
228 let payload: HashMap<String, serde_json::Value> =
229 serde_json::from_str(&payload_str).unwrap_or_default();
230 if let Some(val) = payload.get(&key_field) {
231 let mut map = HashMap::new();
232 map.insert(
233 key_field.clone(),
234 val.as_str().unwrap_or_default().to_owned(),
235 );
236 result.insert(id, map);
237 }
238 }
239 Ok(result)
240 })
241 }
242
243 fn health_check(&self) -> BoxFuture<'_, Result<bool, VectorStoreError>> {
244 Box::pin(async move {
245 sqlx::query_scalar::<_, i32>("SELECT 1")
246 .fetch_one(&self.pool)
247 .await
248 .map(|_| true)
249 .map_err(|e| VectorStoreError::Collection(e.to_string()))
250 })
251 }
252}
253
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sqlite::SqliteStore;
    use crate::vector_store::FieldCondition;

    /// Collection used by every test that is not about collection management.
    const COLLECTION: &str = "c";

    /// Builds the store under test on top of an in-memory database.
    ///
    /// The backing `SqliteStore` is returned too so a test can reach the raw pool.
    async fn setup() -> (SqliteVectorStore, SqliteStore) {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let vs = SqliteVectorStore::new(store.pool().clone());
        (vs, store)
    }

    /// Like [`setup`], with `COLLECTION` registered and `points` already stored.
    async fn store_with(points: Vec<VectorPoint>) -> (SqliteVectorStore, SqliteStore) {
        let (vs, store) = setup().await;
        vs.ensure_collection(COLLECTION, 4).await.unwrap();
        vs.upsert(COLLECTION, points).await.unwrap();
        (vs, store)
    }

    /// Unit vector along the first axis; also the default query vector.
    fn x_axis() -> Vec<f32> {
        vec![1.0, 0.0, 0.0, 0.0]
    }

    /// Unit vector along the second axis, orthogonal to [`x_axis`].
    fn y_axis() -> Vec<f32> {
        vec![0.0, 1.0, 0.0, 0.0]
    }

    /// Builds a point from an id, a vector and a list of payload entries.
    fn point<const N: usize>(
        id: &str,
        vector: Vec<f32>,
        payload: [(&str, serde_json::Value); N],
    ) -> VectorPoint {
        VectorPoint {
            id: id.into(),
            vector,
            payload: HashMap::from(payload.map(|(key, value)| (key.into(), value))),
        }
    }

    /// Condition "`field` equals the text `value`".
    fn text_is(field: &str, value: &str) -> FieldCondition {
        FieldCondition {
            field: field.into(),
            value: FieldValue::Text(value.into()),
        }
    }

    /// Condition "`field` equals the integer `value`".
    fn int_is(field: &str, value: i64) -> FieldCondition {
        FieldCondition {
            field: field.into(),
            value: FieldValue::Integer(value),
        }
    }

    /// Searches `COLLECTION` with [`x_axis`] and a limit of 10, returning hit ids in
    /// rank order.
    async fn hit_ids(vs: &SqliteVectorStore, filter: Option<VectorFilter>) -> Vec<String> {
        vs.search(COLLECTION, x_axis(), 10, filter)
            .await
            .unwrap()
            .into_iter()
            .map(|hit| hit.id)
            .collect()
    }

    #[tokio::test]
    async fn ensure_and_exists() {
        let (vs, _) = setup().await;
        assert!(!vs.collection_exists("col1").await.unwrap());
        vs.ensure_collection("col1", 4).await.unwrap();
        assert!(vs.collection_exists("col1").await.unwrap());
        // Registering the same collection twice must be harmless.
        vs.ensure_collection("col1", 4).await.unwrap();
        assert!(vs.collection_exists("col1").await.unwrap());
    }

    #[tokio::test]
    async fn delete_collection() {
        let (vs, _) = setup().await;
        vs.ensure_collection("col1", 4).await.unwrap();
        vs.upsert("col1", vec![point("p1", x_axis(), [])])
            .await
            .unwrap();
        vs.delete_collection("col1").await.unwrap();
        assert!(!vs.collection_exists("col1").await.unwrap());
    }

    #[tokio::test]
    async fn upsert_and_search() {
        let (vs, _) = store_with(vec![
            point("a", x_axis(), [("role", serde_json::json!("user"))]),
            point("b", y_axis(), [("role", serde_json::json!("assistant"))]),
        ])
        .await;

        let results = vs.search(COLLECTION, x_axis(), 2, None).await.unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "a");
        assert!((results[0].score - 1.0).abs() < 1e-5);
    }

    #[tokio::test]
    async fn search_respects_limit() {
        let (vs, _) = store_with(vec![
            point("a", x_axis(), []),
            point("b", y_axis(), []),
        ])
        .await;

        let results = vs.search(COLLECTION, x_axis(), 1, None).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "a");
    }

    #[tokio::test]
    async fn search_with_filter() {
        let (vs, _) = store_with(vec![
            point("a", x_axis(), [("role", serde_json::json!("user"))]),
            point("b", x_axis(), [("role", serde_json::json!("assistant"))]),
        ])
        .await;

        let filter = VectorFilter {
            must: vec![text_is("role", "user")],
            must_not: vec![],
        };
        assert_eq!(hit_ids(&vs, Some(filter)).await, ["a"]);
    }

    #[tokio::test]
    async fn delete_by_ids() {
        let (vs, _) = store_with(vec![
            point("a", x_axis(), []),
            point("b", y_axis(), []),
        ])
        .await;

        vs.delete_by_ids(COLLECTION, vec!["a".into()]).await.unwrap();
        assert_eq!(hit_ids(&vs, None).await, ["b"]);
    }

    #[tokio::test]
    async fn scroll_all() {
        let (vs, _) = store_with(vec![point(
            "p1",
            x_axis(),
            [("text", serde_json::json!("hello"))],
        )])
        .await;

        let result = vs.scroll_all(COLLECTION, "text").await.unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result["p1"]["text"], "hello");
    }

    #[tokio::test]
    async fn upsert_updates_existing() {
        let (vs, _) = store_with(vec![point("p1", x_axis(), [("v", serde_json::json!(1))])]).await;
        vs.upsert(
            COLLECTION,
            vec![point("p1", y_axis(), [("v", serde_json::json!(2))])],
        )
        .await
        .unwrap();

        // Only the replacement vector can score 1.0 against the second axis.
        let results = vs.search(COLLECTION, y_axis(), 1, None).await.unwrap();
        assert_eq!(results.len(), 1);
        assert!((results[0].score - 1.0).abs() < 1e-5);
    }

    #[test]
    fn cosine_similarity_orthogonal() {
        assert_eq!(cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]), 0.0);
    }

    #[test]
    fn cosine_similarity_identical() {
        let v = vec![3.0, 4.0];
        assert!((cosine_similarity(&v, &v) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_zero_vector() {
        assert_eq!(cosine_similarity(&[0.0, 0.0], &[1.0, 0.0]), 0.0);
    }

    #[test]
    fn cosine_similarity_length_mismatch() {
        assert_eq!(cosine_similarity(&[1.0], &[1.0, 0.0]), 0.0);
    }

    #[tokio::test]
    async fn search_with_must_not_filter() {
        let (vs, _) = store_with(vec![
            point("a", x_axis(), [("role", serde_json::json!("user"))]),
            point("b", x_axis(), [("role", serde_json::json!("system"))]),
        ])
        .await;

        let filter = VectorFilter {
            must: vec![],
            must_not: vec![text_is("role", "system")],
        };
        assert_eq!(hit_ids(&vs, Some(filter)).await, ["a"]);
    }

    #[tokio::test]
    async fn search_with_integer_filter() {
        let (vs, _) = store_with(vec![
            point("a", x_axis(), [("conv_id", serde_json::json!(1))]),
            point("b", x_axis(), [("conv_id", serde_json::json!(2))]),
        ])
        .await;

        let filter = VectorFilter {
            must: vec![int_is("conv_id", 1)],
            must_not: vec![],
        };
        assert_eq!(hit_ids(&vs, Some(filter)).await, ["a"]);
    }

    #[tokio::test]
    async fn search_empty_collection() {
        let (vs, _) = setup().await;
        vs.ensure_collection(COLLECTION, 4).await.unwrap();
        assert!(hit_ids(&vs, None).await.is_empty());
    }

    #[tokio::test]
    async fn search_with_must_not_integer_filter() {
        let (vs, _) = store_with(vec![
            point("a", x_axis(), [("conv_id", serde_json::json!(1))]),
            point("b", x_axis(), [("conv_id", serde_json::json!(2))]),
        ])
        .await;

        let filter = VectorFilter {
            must: vec![],
            must_not: vec![int_is("conv_id", 1)],
        };
        assert_eq!(hit_ids(&vs, Some(filter)).await, ["b"]);
    }

    #[tokio::test]
    async fn search_with_combined_must_and_must_not() {
        let (vs, _) = store_with(vec![
            point(
                "a",
                x_axis(),
                [
                    ("role", serde_json::json!("user")),
                    ("conv_id", serde_json::json!(1)),
                ],
            ),
            point(
                "b",
                x_axis(),
                [
                    ("role", serde_json::json!("user")),
                    ("conv_id", serde_json::json!(2)),
                ],
            ),
            point(
                "c",
                x_axis(),
                [
                    ("role", serde_json::json!("assistant")),
                    ("conv_id", serde_json::json!(1)),
                ],
            ),
        ])
        .await;

        let filter = VectorFilter {
            must: vec![text_is("role", "user")],
            must_not: vec![int_is("conv_id", 2)],
        };
        assert_eq!(hit_ids(&vs, Some(filter)).await, ["a"]);
    }

    #[tokio::test]
    async fn scroll_all_missing_key_field() {
        let (vs, _) = store_with(vec![point(
            "p1",
            x_axis(),
            [("other", serde_json::json!("value"))],
        )])
        .await;

        let result = vs.scroll_all(COLLECTION, "text").await.unwrap();
        assert!(
            result.is_empty(),
            "points without the key field must not appear in scroll result"
        );
    }

    #[tokio::test]
    async fn delete_by_ids_empty_and_nonexistent() {
        let (vs, _) = store_with(vec![point("a", x_axis(), [])]).await;

        // Neither an empty id list nor an unknown id may fail or delete anything.
        vs.delete_by_ids(COLLECTION, vec![]).await.unwrap();
        vs.delete_by_ids(COLLECTION, vec!["nonexistent".into()])
            .await
            .unwrap();

        assert_eq!(hit_ids(&vs, None).await, ["a"]);
    }

    #[tokio::test]
    async fn search_corrupt_blob_skipped() {
        let (vs, store) = store_with(vec![point("valid", x_axis(), [])]).await;

        // Three bytes cannot be a sequence of `f32`s; insert the row behind the
        // store's back to simulate on-disk corruption.
        let corrupt_blob: Vec<u8> = vec![0xFF, 0xFE, 0xFD];
        let payload_json = r#"{}"#;
        sqlx::query(
            "INSERT INTO vector_points (id, collection, vector, payload) VALUES (?, ?, ?, ?)",
        )
        .bind("corrupt")
        .bind(COLLECTION)
        .bind(&corrupt_blob)
        .bind(payload_json)
        .execute(store.pool())
        .await
        .unwrap();

        assert_eq!(hit_ids(&vs, None).await, ["valid"]);
    }
}