1use std::collections::HashMap;
5#[allow(unused_imports)]
6use zeph_db::sql;
7
8use zeph_db::{ActiveDialect, DbPool};
9
10use crate::vector_store::{
11 BoxFuture, FieldValue, ScoredVectorPoint, ScrollResult, VectorFilter, VectorPoint, VectorStore,
12 VectorStoreError,
13};
14
/// [`VectorStore`] implementation backed by the relational database behind a [`DbPool`].
///
/// Collections are rows in `vector_collections`; points are rows in `vector_points`, with the
/// vector stored as a blob of little-endian `f32`s and the payload stored as JSON text.
/// Similarity search is brute force: every point of the collection is loaded and scored in
/// process rather than by the database.
pub struct DbVectorStore {
    // Connection pool shared with the rest of the store; every trait method queries through it.
    pool: DbPool,
}
22
/// Alternative name for [`DbVectorStore`] — the exact same type.
///
/// Presumably retained so callers written against the SQLite-specific name keep compiling.
pub type SqliteVectorStore = DbVectorStore;
25
26impl DbVectorStore {
27 #[must_use]
28 pub fn new(pool: DbPool) -> Self {
29 Self { pool }
30 }
31}
32
33use crate::math::cosine_similarity;
34
35fn matches_filter(payload: &HashMap<String, serde_json::Value>, filter: &VectorFilter) -> bool {
36 for cond in &filter.must {
37 let Some(val) = payload.get(&cond.field) else {
38 return false;
39 };
40 let matches = match &cond.value {
41 FieldValue::Integer(i) => val.as_i64().is_some_and(|v| v == *i),
42 FieldValue::Text(t) => val.as_str().is_some_and(|v| v == t.as_str()),
43 };
44 if !matches {
45 return false;
46 }
47 }
48 for cond in &filter.must_not {
49 let Some(val) = payload.get(&cond.field) else {
50 continue;
51 };
52 let matches = match &cond.value {
53 FieldValue::Integer(i) => val.as_i64().is_some_and(|v| v == *i),
54 FieldValue::Text(t) => val.as_str().is_some_and(|v| v == t.as_str()),
55 };
56 if matches {
57 return false;
58 }
59 }
60 true
61}
62
63impl VectorStore for DbVectorStore {
64 fn ensure_collection(
65 &self,
66 collection: &str,
67 _vector_size: u64,
68 ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
69 let collection = collection.to_owned();
70 Box::pin(async move {
71 let sql = format!(
72 "{} INTO vector_collections (name) VALUES (?){}",
73 <ActiveDialect as zeph_db::dialect::Dialect>::INSERT_IGNORE,
74 <ActiveDialect as zeph_db::dialect::Dialect>::CONFLICT_NOTHING,
75 );
76 zeph_db::query(&sql)
77 .bind(&collection)
78 .execute(&self.pool)
79 .await
80 .map_err(|e| VectorStoreError::Collection(e.to_string()))?;
81 Ok(())
82 })
83 }
84
85 fn collection_exists(&self, collection: &str) -> BoxFuture<'_, Result<bool, VectorStoreError>> {
86 let collection = collection.to_owned();
87 Box::pin(async move {
88 let row: (i64,) = zeph_db::query_as(sql!(
89 "SELECT COUNT(*) FROM vector_collections WHERE name = ?"
90 ))
91 .bind(&collection)
92 .fetch_one(&self.pool)
93 .await
94 .map_err(|e| VectorStoreError::Connection(e.to_string()))?;
95 Ok(row.0 > 0)
96 })
97 }
98
99 fn delete_collection(&self, collection: &str) -> BoxFuture<'_, Result<(), VectorStoreError>> {
100 let collection = collection.to_owned();
101 Box::pin(async move {
102 zeph_db::query(sql!("DELETE FROM vector_points WHERE collection = ?"))
103 .bind(&collection)
104 .execute(&self.pool)
105 .await
106 .map_err(|e| VectorStoreError::Delete(e.to_string()))?;
107 zeph_db::query(sql!("DELETE FROM vector_collections WHERE name = ?"))
108 .bind(&collection)
109 .execute(&self.pool)
110 .await
111 .map_err(|e| VectorStoreError::Delete(e.to_string()))?;
112 Ok(())
113 })
114 }
115
116 fn upsert(
117 &self,
118 collection: &str,
119 points: Vec<VectorPoint>,
120 ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
121 let collection = collection.to_owned();
122 Box::pin(async move {
123 for point in points {
124 let vector_bytes: Vec<u8> =
125 point.vector.iter().flat_map(|f| f.to_le_bytes()).collect();
126 let payload_json = serde_json::to_string(&point.payload)
127 .map_err(|e| VectorStoreError::Serialization(e.to_string()))?;
128 zeph_db::query(
129 sql!("INSERT INTO vector_points (id, collection, vector, payload) VALUES (?, ?, ?, ?) \
130 ON CONFLICT(collection, id) DO UPDATE SET vector = excluded.vector, payload = excluded.payload"),
131 )
132 .bind(&point.id)
133 .bind(&collection)
134 .bind(&vector_bytes)
135 .bind(&payload_json)
136 .execute(&self.pool)
137 .await
138 .map_err(|e| VectorStoreError::Upsert(e.to_string()))?;
139 }
140 Ok(())
141 })
142 }
143
144 fn search(
145 &self,
146 collection: &str,
147 vector: Vec<f32>,
148 limit: u64,
149 filter: Option<VectorFilter>,
150 ) -> BoxFuture<'_, Result<Vec<ScoredVectorPoint>, VectorStoreError>> {
151 let collection = collection.to_owned();
152 Box::pin(async move {
153 let rows: Vec<(String, Vec<u8>, String)> = zeph_db::query_as(sql!(
154 "SELECT id, vector, payload FROM vector_points WHERE collection = ?"
155 ))
156 .bind(&collection)
157 .fetch_all(&self.pool)
158 .await
159 .map_err(|e| VectorStoreError::Search(e.to_string()))?;
160
161 let limit_usize = usize::try_from(limit).unwrap_or(usize::MAX);
162 let mut scored: Vec<ScoredVectorPoint> = rows
163 .into_iter()
164 .filter_map(|(id, blob, payload_str)| {
165 if blob.len() % 4 != 0 {
166 return None;
167 }
168 let stored: Vec<f32> = blob
169 .chunks_exact(4)
170 .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
171 .collect();
172 let payload: HashMap<String, serde_json::Value> =
173 serde_json::from_str(&payload_str).unwrap_or_default();
174
175 if filter
176 .as_ref()
177 .is_some_and(|f| !matches_filter(&payload, f))
178 {
179 return None;
180 }
181
182 let score = cosine_similarity(&vector, &stored);
183 Some(ScoredVectorPoint { id, score, payload })
184 })
185 .collect();
186
187 scored.sort_by(|a, b| {
188 b.score
189 .partial_cmp(&a.score)
190 .unwrap_or(std::cmp::Ordering::Equal)
191 });
192 scored.truncate(limit_usize);
193 Ok(scored)
194 })
195 }
196
197 fn delete_by_ids(
198 &self,
199 collection: &str,
200 ids: Vec<String>,
201 ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
202 let collection = collection.to_owned();
203 Box::pin(async move {
204 for id in ids {
205 zeph_db::query(sql!(
206 "DELETE FROM vector_points WHERE collection = ? AND id = ?"
207 ))
208 .bind(&collection)
209 .bind(&id)
210 .execute(&self.pool)
211 .await
212 .map_err(|e| VectorStoreError::Delete(e.to_string()))?;
213 }
214 Ok(())
215 })
216 }
217
218 fn scroll_all(
219 &self,
220 collection: &str,
221 key_field: &str,
222 ) -> BoxFuture<'_, Result<ScrollResult, VectorStoreError>> {
223 let collection = collection.to_owned();
224 let key_field = key_field.to_owned();
225 Box::pin(async move {
226 let rows: Vec<(String, String)> = zeph_db::query_as(sql!(
227 "SELECT id, payload FROM vector_points WHERE collection = ?"
228 ))
229 .bind(&collection)
230 .fetch_all(&self.pool)
231 .await
232 .map_err(|e| VectorStoreError::Scroll(e.to_string()))?;
233
234 let mut result = ScrollResult::new();
235 for (id, payload_str) in rows {
236 let payload: HashMap<String, serde_json::Value> =
237 serde_json::from_str(&payload_str).unwrap_or_default();
238 if let Some(val) = payload.get(&key_field) {
239 let mut map = HashMap::new();
240 map.insert(
241 key_field.clone(),
242 val.as_str().unwrap_or_default().to_owned(),
243 );
244 result.insert(id, map);
245 }
246 }
247 Ok(result)
248 })
249 }
250
251 fn health_check(&self) -> BoxFuture<'_, Result<bool, VectorStoreError>> {
252 Box::pin(async move {
253 zeph_db::query_scalar::<_, i32>(sql!("SELECT 1"))
254 .fetch_one(&self.pool)
255 .await
256 .map(|_| true)
257 .map_err(|e| VectorStoreError::Collection(e.to_string()))
258 })
259 }
260}
261
#[cfg(test)]
mod tests {
    use super::*;
    use crate::store::DbStore;
    use crate::vector_store::FieldCondition;

    /// Unit vector along the first axis — the query vector for most tests.
    const X_AXIS: [f32; 4] = [1.0, 0.0, 0.0, 0.0];
    /// Unit vector along the second axis, orthogonal to [`X_AXIS`].
    const Y_AXIS: [f32; 4] = [0.0, 1.0, 0.0, 0.0];

    /// Opens an in-memory database and builds a vector store on its pool.
    /// The `DbStore` is returned as well so tests can reach the pool directly.
    async fn setup() -> (DbVectorStore, DbStore) {
        let store = DbStore::new(":memory:").await.unwrap();
        let vs = DbVectorStore::new(store.pool().clone());
        (vs, store)
    }

    /// Builds a point from an id, a 4-d vector and `(key, value)` payload pairs.
    fn point<const N: usize>(
        id: &str,
        vector: [f32; 4],
        payload: [(&str, serde_json::Value); N],
    ) -> VectorPoint {
        VectorPoint {
            id: id.into(),
            vector: vector.to_vec(),
            payload: payload
                .into_iter()
                .map(|(key, value)| (key.to_owned(), value))
                .collect(),
        }
    }

    /// Condition "`field` equals the string `value`".
    fn text_is(field: &str, value: &str) -> FieldCondition {
        FieldCondition {
            field: field.into(),
            value: FieldValue::Text(value.into()),
        }
    }

    /// Condition "`field` equals the integer `value`".
    fn int_is(field: &str, value: i64) -> FieldCondition {
        FieldCondition {
            field: field.into(),
            value: FieldValue::Integer(value),
        }
    }

    /// Searches collection `c` for [`X_AXIS`], panicking on error.
    async fn search_x(
        vs: &DbVectorStore,
        limit: u64,
        filter: Option<VectorFilter>,
    ) -> Vec<ScoredVectorPoint> {
        vs.search("c", X_AXIS.to_vec(), limit, filter)
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn ensure_and_exists() {
        let (vs, _) = setup().await;
        assert!(!vs.collection_exists("col1").await.unwrap());
        // The second round checks that re-ensuring an existing collection is harmless.
        for _ in 0..2 {
            vs.ensure_collection("col1", 4).await.unwrap();
            assert!(vs.collection_exists("col1").await.unwrap());
        }
    }

    #[tokio::test]
    async fn delete_collection() {
        let (vs, _) = setup().await;
        vs.ensure_collection("col1", 4).await.unwrap();
        vs.upsert("col1", vec![point("p1", X_AXIS, [])])
            .await
            .unwrap();
        vs.delete_collection("col1").await.unwrap();
        assert!(!vs.collection_exists("col1").await.unwrap());
    }

    #[tokio::test]
    async fn upsert_and_search() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        let points = vec![
            point("a", X_AXIS, [("role", serde_json::json!("user"))]),
            point("b", Y_AXIS, [("role", serde_json::json!("assistant"))]),
        ];
        vs.upsert("c", points).await.unwrap();

        let hits = search_x(&vs, 2, None).await;
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].id, "a");
        assert!((hits[0].score - 1.0).abs() < 1e-5);
    }

    #[tokio::test]
    async fn search_with_filter() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        let points = vec![
            point("a", X_AXIS, [("role", serde_json::json!("user"))]),
            point("b", X_AXIS, [("role", serde_json::json!("assistant"))]),
        ];
        vs.upsert("c", points).await.unwrap();

        let only_users = VectorFilter {
            must: vec![text_is("role", "user")],
            must_not: vec![],
        };
        let hits = search_x(&vs, 10, Some(only_users)).await;
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].id, "a");
    }

    #[tokio::test]
    async fn delete_by_ids() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        let points = vec![point("a", X_AXIS, []), point("b", Y_AXIS, [])];
        vs.upsert("c", points).await.unwrap();

        vs.delete_by_ids("c", vec!["a".into()]).await.unwrap();
        let hits = search_x(&vs, 10, None).await;
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].id, "b");
    }

    #[tokio::test]
    async fn scroll_all() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        let points = vec![point("p1", X_AXIS, [("text", serde_json::json!("hello"))])];
        vs.upsert("c", points).await.unwrap();

        let scrolled = vs.scroll_all("c", "text").await.unwrap();
        assert_eq!(scrolled.len(), 1);
        assert_eq!(scrolled["p1"]["text"], "hello");
    }

    #[tokio::test]
    async fn upsert_updates_existing() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        let first = vec![point("p1", X_AXIS, [("v", serde_json::json!(1))])];
        vs.upsert("c", first).await.unwrap();
        // Same id, new vector: the row must be overwritten rather than duplicated.
        let second = vec![point("p1", Y_AXIS, [("v", serde_json::json!(2))])];
        vs.upsert("c", second).await.unwrap();

        let hits = vs
            .search("c", Y_AXIS.to_vec(), 1, None)
            .await
            .unwrap();
        assert_eq!(hits.len(), 1);
        assert!((hits[0].score - 1.0).abs() < 1e-5);
    }

    #[test]
    fn cosine_similarity_import_wired() {
        let orthogonal = cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]);
        assert!(!orthogonal.is_nan());
    }

    #[tokio::test]
    async fn search_with_must_not_filter() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        let points = vec![
            point("a", X_AXIS, [("role", serde_json::json!("user"))]),
            point("b", X_AXIS, [("role", serde_json::json!("system"))]),
        ];
        vs.upsert("c", points).await.unwrap();

        let no_system = VectorFilter {
            must: vec![],
            must_not: vec![text_is("role", "system")],
        };
        let hits = search_x(&vs, 10, Some(no_system)).await;
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].id, "a");
    }

    #[tokio::test]
    async fn search_with_integer_filter() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        let points = vec![
            point("a", X_AXIS, [("conv_id", serde_json::json!(1))]),
            point("b", X_AXIS, [("conv_id", serde_json::json!(2))]),
        ];
        vs.upsert("c", points).await.unwrap();

        let first_conv = VectorFilter {
            must: vec![int_is("conv_id", 1)],
            must_not: vec![],
        };
        let hits = search_x(&vs, 10, Some(first_conv)).await;
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].id, "a");
    }

    #[tokio::test]
    async fn search_empty_collection() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        let hits = search_x(&vs, 10, None).await;
        assert!(hits.is_empty());
    }

    #[tokio::test]
    async fn search_with_must_not_integer_filter() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        let points = vec![
            point("a", X_AXIS, [("conv_id", serde_json::json!(1))]),
            point("b", X_AXIS, [("conv_id", serde_json::json!(2))]),
        ];
        vs.upsert("c", points).await.unwrap();

        let not_first_conv = VectorFilter {
            must: vec![],
            must_not: vec![int_is("conv_id", 1)],
        };
        let hits = search_x(&vs, 10, Some(not_first_conv)).await;
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].id, "b");
    }

    #[tokio::test]
    async fn search_with_combined_must_and_must_not() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        let points = vec![
            point(
                "a",
                X_AXIS,
                [
                    ("role", serde_json::json!("user")),
                    ("conv_id", serde_json::json!(1)),
                ],
            ),
            point(
                "b",
                X_AXIS,
                [
                    ("role", serde_json::json!("user")),
                    ("conv_id", serde_json::json!(2)),
                ],
            ),
            point(
                "c",
                X_AXIS,
                [
                    ("role", serde_json::json!("assistant")),
                    ("conv_id", serde_json::json!(1)),
                ],
            ),
        ];
        vs.upsert("c", points).await.unwrap();

        // Only "a" is a user message outside conversation 2.
        let combined = VectorFilter {
            must: vec![text_is("role", "user")],
            must_not: vec![int_is("conv_id", 2)],
        };
        let hits = search_x(&vs, 10, Some(combined)).await;
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].id, "a");
    }

    #[tokio::test]
    async fn scroll_all_missing_key_field() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        let points = vec![point("p1", X_AXIS, [("other", serde_json::json!("value"))])];
        vs.upsert("c", points).await.unwrap();

        let scrolled = vs.scroll_all("c", "text").await.unwrap();
        assert!(
            scrolled.is_empty(),
            "points without the key field must not appear in scroll result"
        );
    }

    #[tokio::test]
    async fn delete_by_ids_empty_and_nonexistent() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        vs.upsert("c", vec![point("a", X_AXIS, [])]).await.unwrap();

        // Neither an empty id list nor an unknown id may error or touch "a".
        vs.delete_by_ids("c", Vec::new()).await.unwrap();
        vs.delete_by_ids("c", vec!["nonexistent".into()])
            .await
            .unwrap();

        let hits = search_x(&vs, 10, None).await;
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].id, "a");
    }

    #[tokio::test]
    async fn search_corrupt_blob_skipped() {
        let (vs, store) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        vs.upsert("c", vec![point("valid", X_AXIS, [])])
            .await
            .unwrap();

        // Three bytes cannot encode whole f32s, so this row must be ignored by search.
        let corrupt_blob: Vec<u8> = vec![0xFF, 0xFE, 0xFD];
        zeph_db::query(sql!(
            "INSERT INTO vector_points (id, collection, vector, payload) VALUES (?, ?, ?, ?)"
        ))
        .bind("corrupt")
        .bind("c")
        .bind(&corrupt_blob)
        .bind(r"{}")
        .execute(store.pool())
        .await
        .unwrap();

        let hits = search_x(&vs, 10, None).await;
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].id, "valid");
    }
}