1use std::collections::HashMap;
11#[allow(unused_imports)]
12use zeph_db::sql;
13
14use zeph_db::{ActiveDialect, DbPool};
15
16use crate::vector_store::{
17 BoxFuture, FieldValue, ScoredVectorPoint, ScrollResult, VectorFilter, VectorPoint, VectorStore,
18 VectorStoreError,
19};
20
/// [`VectorStore`] implementation that keeps collections and points in the SQL tables
/// `vector_collections` and `vector_points`, reached through a [`DbPool`].
///
/// Similarity search is not delegated to the database: `search` loads every point of the
/// requested collection and scores it in process with cosine similarity.
pub struct DbVectorStore {
    /// Connection pool that every statement issued by this store runs against.
    pool: DbPool,
}
28
/// Alternative name for [`DbVectorStore`]; both names refer to the same type.
// NOTE(review): presumably kept so code written against an earlier SQLite-only name keeps
// compiling — confirm before removing.
pub type SqliteVectorStore = DbVectorStore;
31
impl DbVectorStore {
    /// Creates a vector store that runs its statements against `pool`.
    ///
    /// The constructor only stores the pool; no statement is executed here.
    // NOTE(review): the `vector_collections` / `vector_points` tables are presumably created
    // by migrations elsewhere — confirm against the pool's setup code.
    #[must_use]
    pub fn new(pool: DbPool) -> Self {
        Self { pool }
    }
}
42
43use zeph_common::math::cosine_similarity;
44
45fn matches_filter(payload: &HashMap<String, serde_json::Value>, filter: &VectorFilter) -> bool {
46 for cond in &filter.must {
47 let Some(val) = payload.get(&cond.field) else {
48 return false;
49 };
50 let matches = match &cond.value {
51 FieldValue::Integer(i) => val.as_i64().is_some_and(|v| v == *i),
52 FieldValue::Text(t) => val.as_str().is_some_and(|v| v == t.as_str()),
53 };
54 if !matches {
55 return false;
56 }
57 }
58 for cond in &filter.must_not {
59 let Some(val) = payload.get(&cond.field) else {
60 continue;
61 };
62 let matches = match &cond.value {
63 FieldValue::Integer(i) => val.as_i64().is_some_and(|v| v == *i),
64 FieldValue::Text(t) => val.as_str().is_some_and(|v| v == t.as_str()),
65 };
66 if matches {
67 return false;
68 }
69 }
70 true
71}
72
73impl VectorStore for DbVectorStore {
74 fn ensure_collection(
75 &self,
76 collection: &str,
77 _vector_size: u64,
78 ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
79 let collection = collection.to_owned();
80 Box::pin(async move {
81 let sql = format!(
82 "{} INTO vector_collections (name) VALUES (?){}",
83 <ActiveDialect as zeph_db::dialect::Dialect>::INSERT_IGNORE,
84 <ActiveDialect as zeph_db::dialect::Dialect>::CONFLICT_NOTHING,
85 );
86 zeph_db::query(&sql)
87 .bind(&collection)
88 .execute(&self.pool)
89 .await
90 .map_err(|e| VectorStoreError::Collection(e.to_string()))?;
91 Ok(())
92 })
93 }
94
95 fn collection_exists(&self, collection: &str) -> BoxFuture<'_, Result<bool, VectorStoreError>> {
96 let collection = collection.to_owned();
97 Box::pin(async move {
98 let row: (i64,) = zeph_db::query_as(sql!(
99 "SELECT COUNT(*) FROM vector_collections WHERE name = ?"
100 ))
101 .bind(&collection)
102 .fetch_one(&self.pool)
103 .await
104 .map_err(|e| VectorStoreError::Connection(e.to_string()))?;
105 Ok(row.0 > 0)
106 })
107 }
108
109 fn delete_collection(&self, collection: &str) -> BoxFuture<'_, Result<(), VectorStoreError>> {
110 let collection = collection.to_owned();
111 Box::pin(async move {
112 zeph_db::query(sql!("DELETE FROM vector_points WHERE collection = ?"))
113 .bind(&collection)
114 .execute(&self.pool)
115 .await
116 .map_err(|e| VectorStoreError::Delete(e.to_string()))?;
117 zeph_db::query(sql!("DELETE FROM vector_collections WHERE name = ?"))
118 .bind(&collection)
119 .execute(&self.pool)
120 .await
121 .map_err(|e| VectorStoreError::Delete(e.to_string()))?;
122 Ok(())
123 })
124 }
125
126 fn upsert(
127 &self,
128 collection: &str,
129 points: Vec<VectorPoint>,
130 ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
131 let collection = collection.to_owned();
132 #[cfg(feature = "profiling")]
133 let span = tracing::info_span!(
134 "memory.vector_store",
135 operation = "upsert",
136 collection = %collection
137 );
138 let fut = Box::pin(async move {
139 for point in points {
140 let vector_bytes: Vec<u8> =
141 point.vector.iter().flat_map(|f| f.to_le_bytes()).collect();
142 let payload_json = serde_json::to_string(&point.payload)
143 .map_err(|e| VectorStoreError::Serialization(e.to_string()))?;
144 zeph_db::query(
145 sql!("INSERT INTO vector_points (id, collection, vector, payload) VALUES (?, ?, ?, ?) \
146 ON CONFLICT(collection, id) DO UPDATE SET vector = excluded.vector, payload = excluded.payload"),
147 )
148 .bind(&point.id)
149 .bind(&collection)
150 .bind(&vector_bytes)
151 .bind(&payload_json)
152 .execute(&self.pool)
153 .await
154 .map_err(|e| VectorStoreError::Upsert(e.to_string()))?;
155 }
156 Ok(())
157 });
158 #[cfg(feature = "profiling")]
159 return Box::pin(tracing::Instrument::instrument(fut, span));
160 #[cfg(not(feature = "profiling"))]
161 fut
162 }
163
164 fn search(
165 &self,
166 collection: &str,
167 vector: Vec<f32>,
168 limit: u64,
169 filter: Option<VectorFilter>,
170 ) -> BoxFuture<'_, Result<Vec<ScoredVectorPoint>, VectorStoreError>> {
171 let collection = collection.to_owned();
172 #[cfg(feature = "profiling")]
173 let span = tracing::info_span!(
174 "memory.vector_store",
175 operation = "search",
176 collection = %collection
177 );
178 let fut = Box::pin(async move {
179 let rows: Vec<(String, Vec<u8>, String)> = zeph_db::query_as(sql!(
180 "SELECT id, vector, payload FROM vector_points WHERE collection = ?"
181 ))
182 .bind(&collection)
183 .fetch_all(&self.pool)
184 .await
185 .map_err(|e| VectorStoreError::Search(e.to_string()))?;
186
187 let limit_usize = usize::try_from(limit).unwrap_or(usize::MAX);
188 let mut scored: Vec<ScoredVectorPoint> = rows
189 .into_iter()
190 .filter_map(|(id, blob, payload_str)| {
191 if blob.len() % 4 != 0 {
192 return None;
193 }
194 let stored: Vec<f32> = blob
195 .chunks_exact(4)
196 .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
197 .collect();
198 let payload: HashMap<String, serde_json::Value> =
199 serde_json::from_str(&payload_str).unwrap_or_default();
200
201 if filter
202 .as_ref()
203 .is_some_and(|f| !matches_filter(&payload, f))
204 {
205 return None;
206 }
207
208 let score = cosine_similarity(&vector, &stored);
209 Some(ScoredVectorPoint { id, score, payload })
210 })
211 .collect();
212
213 scored.sort_by(|a, b| {
214 b.score
215 .partial_cmp(&a.score)
216 .unwrap_or(std::cmp::Ordering::Equal)
217 });
218 scored.truncate(limit_usize);
219 Ok(scored)
220 });
221 #[cfg(feature = "profiling")]
222 return Box::pin(tracing::Instrument::instrument(fut, span));
223 #[cfg(not(feature = "profiling"))]
224 fut
225 }
226
227 fn delete_by_ids(
228 &self,
229 collection: &str,
230 ids: Vec<String>,
231 ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
232 let collection = collection.to_owned();
233 #[cfg(feature = "profiling")]
234 let span = tracing::info_span!(
235 "memory.vector_store",
236 operation = "delete",
237 collection = %collection
238 );
239 let fut = Box::pin(async move {
240 for id in ids {
241 zeph_db::query(sql!(
242 "DELETE FROM vector_points WHERE collection = ? AND id = ?"
243 ))
244 .bind(&collection)
245 .bind(&id)
246 .execute(&self.pool)
247 .await
248 .map_err(|e| VectorStoreError::Delete(e.to_string()))?;
249 }
250 Ok(())
251 });
252 #[cfg(feature = "profiling")]
253 return Box::pin(tracing::Instrument::instrument(fut, span));
254 #[cfg(not(feature = "profiling"))]
255 fut
256 }
257
258 fn scroll_all(
259 &self,
260 collection: &str,
261 key_field: &str,
262 ) -> BoxFuture<'_, Result<ScrollResult, VectorStoreError>> {
263 let collection = collection.to_owned();
264 let key_field = key_field.to_owned();
265 Box::pin(async move {
266 let rows: Vec<(String, String)> = zeph_db::query_as(sql!(
267 "SELECT id, payload FROM vector_points WHERE collection = ?"
268 ))
269 .bind(&collection)
270 .fetch_all(&self.pool)
271 .await
272 .map_err(|e| VectorStoreError::Scroll(e.to_string()))?;
273
274 let mut result = ScrollResult::new();
275 for (id, payload_str) in rows {
276 let payload: HashMap<String, serde_json::Value> =
277 serde_json::from_str(&payload_str).unwrap_or_default();
278 if let Some(val) = payload.get(&key_field) {
279 let mut map = HashMap::new();
280 map.insert(
281 key_field.clone(),
282 val.as_str().unwrap_or_default().to_owned(),
283 );
284 result.insert(id, map);
285 }
286 }
287 Ok(result)
288 })
289 }
290
291 fn health_check(&self) -> BoxFuture<'_, Result<bool, VectorStoreError>> {
292 Box::pin(async move {
293 zeph_db::query_scalar::<_, i32>(sql!("SELECT 1"))
294 .fetch_one(&self.pool)
295 .await
296 .map(|_| true)
297 .map_err(|e| VectorStoreError::Collection(e.to_string()))
298 })
299 }
300}
301
#[cfg(test)]
mod tests {
    use super::*;
    use crate::store::DbStore;
    use crate::vector_store::FieldCondition;

    /// Unit vector along the first axis; the usual stored and query vector in these tests.
    const X_AXIS: [f32; 4] = [1.0, 0.0, 0.0, 0.0];
    /// Unit vector along the second axis, orthogonal to [`X_AXIS`].
    const Y_AXIS: [f32; 4] = [0.0, 1.0, 0.0, 0.0];

    /// Opens an in-memory [`DbStore`] and wraps a clone of its pool in a [`DbVectorStore`].
    ///
    /// The `DbStore` is handed back as well so a test can reach the pool directly.
    async fn setup() -> (DbVectorStore, DbStore) {
        let store = DbStore::new(":memory:").await.unwrap();
        let pool = store.pool().clone();
        let vs = DbVectorStore::new(pool);
        (vs, store)
    }

    /// Like [`setup`], but also creates collection `"c"` and upserts `points` into it.
    async fn setup_with(points: Vec<VectorPoint>) -> (DbVectorStore, DbStore) {
        let (vs, store) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        if !points.is_empty() {
            vs.upsert("c", points).await.unwrap();
        }
        (vs, store)
    }

    /// Builds a point with the given id, vector and payload entries.
    fn point(id: &str, vector: [f32; 4], fields: &[(&str, serde_json::Value)]) -> VectorPoint {
        VectorPoint {
            id: id.into(),
            vector: vector.to_vec(),
            payload: fields
                .iter()
                .map(|(key, value)| ((*key).to_owned(), value.clone()))
                .collect(),
        }
    }

    /// Condition "`field` equals the text `value`".
    fn text_cond(field: &str, value: &str) -> FieldCondition {
        FieldCondition {
            field: field.into(),
            value: FieldValue::Text(value.into()),
        }
    }

    /// Condition "`field` equals the integer `value`".
    fn int_cond(field: &str, value: i64) -> FieldCondition {
        FieldCondition {
            field: field.into(),
            value: FieldValue::Integer(value),
        }
    }

    /// Searches collection `"c"` for [`X_AXIS`] with limit 10 and the given filter.
    async fn search_x(vs: &DbVectorStore, filter: Option<VectorFilter>) -> Vec<ScoredVectorPoint> {
        vs.search("c", X_AXIS.to_vec(), 10, filter).await.unwrap()
    }

    #[tokio::test]
    async fn ensure_and_exists() {
        let (vs, _) = setup().await;
        assert!(!vs.collection_exists("col1").await.unwrap());
        vs.ensure_collection("col1", 4).await.unwrap();
        assert!(vs.collection_exists("col1").await.unwrap());
        // Registering the same name again must succeed and leave the collection in place.
        vs.ensure_collection("col1", 4).await.unwrap();
        assert!(vs.collection_exists("col1").await.unwrap());
    }

    #[tokio::test]
    async fn delete_collection() {
        let (vs, _) = setup().await;
        vs.ensure_collection("col1", 4).await.unwrap();
        vs.upsert("col1", vec![point("p1", X_AXIS, &[])])
            .await
            .unwrap();
        vs.delete_collection("col1").await.unwrap();
        assert!(!vs.collection_exists("col1").await.unwrap());
        // The collection's points must be gone too, not only its registry row.
        let leftovers = vs
            .search("col1", X_AXIS.to_vec(), 10, None)
            .await
            .unwrap();
        assert!(leftovers.is_empty());
    }

    #[tokio::test]
    async fn upsert_and_search() {
        let (vs, _) = setup_with(vec![
            point("a", X_AXIS, &[("role", serde_json::json!("user"))]),
            point("b", Y_AXIS, &[("role", serde_json::json!("assistant"))]),
        ])
        .await;

        let results = vs.search("c", X_AXIS.to_vec(), 2, None).await.unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "a");
        assert!((results[0].score - 1.0).abs() < 1e-5);
    }

    #[tokio::test]
    async fn search_with_filter() {
        let (vs, _) = setup_with(vec![
            point("a", X_AXIS, &[("role", serde_json::json!("user"))]),
            point("b", X_AXIS, &[("role", serde_json::json!("assistant"))]),
        ])
        .await;

        let filter = VectorFilter {
            must: vec![text_cond("role", "user")],
            must_not: vec![],
        };
        let results = search_x(&vs, Some(filter)).await;
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "a");
    }

    #[tokio::test]
    async fn delete_by_ids() {
        let (vs, _) = setup_with(vec![point("a", X_AXIS, &[]), point("b", Y_AXIS, &[])]).await;

        vs.delete_by_ids("c", vec!["a".into()]).await.unwrap();
        let results = search_x(&vs, None).await;
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "b");
    }

    #[tokio::test]
    async fn scroll_all() {
        let (vs, _) = setup_with(vec![point(
            "p1",
            X_AXIS,
            &[("text", serde_json::json!("hello"))],
        )])
        .await;

        let result = vs.scroll_all("c", "text").await.unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result["p1"]["text"], "hello");
    }

    #[tokio::test]
    async fn upsert_updates_existing() {
        let (vs, _) = setup_with(vec![point("p1", X_AXIS, &[("v", serde_json::json!(1))])]).await;
        vs.upsert(
            "c",
            vec![point("p1", Y_AXIS, &[("v", serde_json::json!(2))])],
        )
        .await
        .unwrap();

        // Search with room for more than one hit: a limit of 1 would hide a duplicate row.
        let results = vs.search("c", Y_AXIS.to_vec(), 10, None).await.unwrap();
        assert_eq!(results.len(), 1, "second upsert must replace, not duplicate");
        assert!((results[0].score - 1.0).abs() < 1e-5);
        assert_eq!(results[0].payload["v"], serde_json::json!(2));
    }

    #[test]
    fn cosine_similarity_import_wired() {
        // Smoke check that the shared helper is reachable and yields a real number.
        assert!(!cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]).is_nan());
    }

    #[tokio::test]
    async fn search_with_must_not_filter() {
        let (vs, _) = setup_with(vec![
            point("a", X_AXIS, &[("role", serde_json::json!("user"))]),
            point("b", X_AXIS, &[("role", serde_json::json!("system"))]),
        ])
        .await;

        let filter = VectorFilter {
            must: vec![],
            must_not: vec![text_cond("role", "system")],
        };
        let results = search_x(&vs, Some(filter)).await;
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "a");
    }

    #[tokio::test]
    async fn search_with_integer_filter() {
        let (vs, _) = setup_with(vec![
            point("a", X_AXIS, &[("conv_id", serde_json::json!(1))]),
            point("b", X_AXIS, &[("conv_id", serde_json::json!(2))]),
        ])
        .await;

        let filter = VectorFilter {
            must: vec![int_cond("conv_id", 1)],
            must_not: vec![],
        };
        let results = search_x(&vs, Some(filter)).await;
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "a");
    }

    #[tokio::test]
    async fn search_empty_collection() {
        let (vs, _) = setup_with(vec![]).await;
        let results = search_x(&vs, None).await;
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn search_with_must_not_integer_filter() {
        let (vs, _) = setup_with(vec![
            point("a", X_AXIS, &[("conv_id", serde_json::json!(1))]),
            point("b", X_AXIS, &[("conv_id", serde_json::json!(2))]),
        ])
        .await;

        let filter = VectorFilter {
            must: vec![],
            must_not: vec![int_cond("conv_id", 1)],
        };
        let results = search_x(&vs, Some(filter)).await;
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "b");
    }

    #[tokio::test]
    async fn search_with_combined_must_and_must_not() {
        let (vs, _) = setup_with(vec![
            point(
                "a",
                X_AXIS,
                &[
                    ("role", serde_json::json!("user")),
                    ("conv_id", serde_json::json!(1)),
                ],
            ),
            point(
                "b",
                X_AXIS,
                &[
                    ("role", serde_json::json!("user")),
                    ("conv_id", serde_json::json!(2)),
                ],
            ),
            point(
                "c",
                X_AXIS,
                &[
                    ("role", serde_json::json!("assistant")),
                    ("conv_id", serde_json::json!(1)),
                ],
            ),
        ])
        .await;

        let filter = VectorFilter {
            must: vec![text_cond("role", "user")],
            must_not: vec![int_cond("conv_id", 2)],
        };
        let results = search_x(&vs, Some(filter)).await;
        // "b" is excluded by `must_not`, "c" fails `must`; only "a" satisfies both.
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "a");
    }

    #[tokio::test]
    async fn scroll_all_missing_key_field() {
        let (vs, _) = setup_with(vec![point(
            "p1",
            X_AXIS,
            &[("other", serde_json::json!("value"))],
        )])
        .await;

        let result = vs.scroll_all("c", "text").await.unwrap();
        assert!(
            result.is_empty(),
            "points without the key field must not appear in scroll result"
        );
    }

    #[tokio::test]
    async fn delete_by_ids_empty_and_nonexistent() {
        let (vs, _) = setup_with(vec![point("a", X_AXIS, &[])]).await;

        // An empty id list is a no-op.
        vs.delete_by_ids("c", vec![]).await.unwrap();

        // An id that matches no row is not an error.
        vs.delete_by_ids("c", vec!["nonexistent".into()])
            .await
            .unwrap();

        // The existing point must have survived both calls.
        let results = search_x(&vs, None).await;
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "a");
    }

    #[tokio::test]
    async fn search_corrupt_blob_skipped() {
        let (vs, store) = setup_with(vec![point("valid", X_AXIS, &[])]).await;

        // Three bytes cannot be a packed f32 array; insert the row behind the store's back.
        let corrupt_blob: Vec<u8> = vec![0xFF, 0xFE, 0xFD];
        let payload_json = r"{}";
        zeph_db::query(sql!(
            "INSERT INTO vector_points (id, collection, vector, payload) VALUES (?, ?, ?, ?)"
        ))
        .bind("corrupt")
        .bind("c")
        .bind(&corrupt_blob)
        .bind(payload_json)
        .execute(store.pool())
        .await
        .unwrap();

        // The corrupt row is skipped silently; only the valid point is returned.
        let results = search_x(&vs, None).await;
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "valid");
    }
}