Skip to main content

zeph_memory/
db_vector_store.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::collections::HashMap;
5#[allow(unused_imports)]
6use zeph_db::sql;
7
8use zeph_db::{ActiveDialect, DbPool};
9
10use crate::vector_store::{
11    BoxFuture, FieldValue, ScoredVectorPoint, ScrollResult, VectorFilter, VectorPoint, VectorStore,
12    VectorStoreError,
13};
14
/// Database-backed in-process vector store.
///
/// Stores vectors as BLOBs in `SQLite` and performs cosine similarity in memory.
/// Each vector is written as the concatenation of its `f32` components in
/// little-endian byte order, and a search reads every point of the requested
/// collection before scoring, so cost grows linearly with collection size.
/// For production-scale workloads, prefer the Qdrant-backed store.
pub struct DbVectorStore {
    /// Connection pool every query of this store is executed against.
    pool: DbPool,
}
22
/// Backward-compatible alias for [`DbVectorStore`]; both names denote the same type.
pub type SqliteVectorStore = DbVectorStore;
25
impl DbVectorStore {
    /// Creates a vector store that runs all of its queries against `pool`.
    ///
    /// No query is issued here; the pool is only stored.
    #[must_use]
    pub fn new(pool: DbPool) -> Self {
        Self { pool }
    }
}
32
33use crate::math::cosine_similarity;
34
35fn matches_filter(payload: &HashMap<String, serde_json::Value>, filter: &VectorFilter) -> bool {
36    for cond in &filter.must {
37        let Some(val) = payload.get(&cond.field) else {
38            return false;
39        };
40        let matches = match &cond.value {
41            FieldValue::Integer(i) => val.as_i64().is_some_and(|v| v == *i),
42            FieldValue::Text(t) => val.as_str().is_some_and(|v| v == t.as_str()),
43        };
44        if !matches {
45            return false;
46        }
47    }
48    for cond in &filter.must_not {
49        let Some(val) = payload.get(&cond.field) else {
50            continue;
51        };
52        let matches = match &cond.value {
53            FieldValue::Integer(i) => val.as_i64().is_some_and(|v| v == *i),
54            FieldValue::Text(t) => val.as_str().is_some_and(|v| v == t.as_str()),
55        };
56        if matches {
57            return false;
58        }
59    }
60    true
61}
62
63impl VectorStore for DbVectorStore {
64    fn ensure_collection(
65        &self,
66        collection: &str,
67        _vector_size: u64,
68    ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
69        let collection = collection.to_owned();
70        Box::pin(async move {
71            let sql = format!(
72                "{} INTO vector_collections (name) VALUES (?){}",
73                <ActiveDialect as zeph_db::dialect::Dialect>::INSERT_IGNORE,
74                <ActiveDialect as zeph_db::dialect::Dialect>::CONFLICT_NOTHING,
75            );
76            zeph_db::query(&sql)
77                .bind(&collection)
78                .execute(&self.pool)
79                .await
80                .map_err(|e| VectorStoreError::Collection(e.to_string()))?;
81            Ok(())
82        })
83    }
84
85    fn collection_exists(&self, collection: &str) -> BoxFuture<'_, Result<bool, VectorStoreError>> {
86        let collection = collection.to_owned();
87        Box::pin(async move {
88            let row: (i64,) = zeph_db::query_as(sql!(
89                "SELECT COUNT(*) FROM vector_collections WHERE name = ?"
90            ))
91            .bind(&collection)
92            .fetch_one(&self.pool)
93            .await
94            .map_err(|e| VectorStoreError::Connection(e.to_string()))?;
95            Ok(row.0 > 0)
96        })
97    }
98
99    fn delete_collection(&self, collection: &str) -> BoxFuture<'_, Result<(), VectorStoreError>> {
100        let collection = collection.to_owned();
101        Box::pin(async move {
102            zeph_db::query(sql!("DELETE FROM vector_points WHERE collection = ?"))
103                .bind(&collection)
104                .execute(&self.pool)
105                .await
106                .map_err(|e| VectorStoreError::Delete(e.to_string()))?;
107            zeph_db::query(sql!("DELETE FROM vector_collections WHERE name = ?"))
108                .bind(&collection)
109                .execute(&self.pool)
110                .await
111                .map_err(|e| VectorStoreError::Delete(e.to_string()))?;
112            Ok(())
113        })
114    }
115
116    fn upsert(
117        &self,
118        collection: &str,
119        points: Vec<VectorPoint>,
120    ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
121        let collection = collection.to_owned();
122        Box::pin(async move {
123            for point in points {
124                let vector_bytes: Vec<u8> =
125                    point.vector.iter().flat_map(|f| f.to_le_bytes()).collect();
126                let payload_json = serde_json::to_string(&point.payload)
127                    .map_err(|e| VectorStoreError::Serialization(e.to_string()))?;
128                zeph_db::query(
129                    sql!("INSERT INTO vector_points (id, collection, vector, payload) VALUES (?, ?, ?, ?) \
130                     ON CONFLICT(collection, id) DO UPDATE SET vector = excluded.vector, payload = excluded.payload"),
131                )
132                .bind(&point.id)
133                .bind(&collection)
134                .bind(&vector_bytes)
135                .bind(&payload_json)
136                .execute(&self.pool)
137                .await
138                .map_err(|e| VectorStoreError::Upsert(e.to_string()))?;
139            }
140            Ok(())
141        })
142    }
143
144    fn search(
145        &self,
146        collection: &str,
147        vector: Vec<f32>,
148        limit: u64,
149        filter: Option<VectorFilter>,
150    ) -> BoxFuture<'_, Result<Vec<ScoredVectorPoint>, VectorStoreError>> {
151        let collection = collection.to_owned();
152        Box::pin(async move {
153            let rows: Vec<(String, Vec<u8>, String)> = zeph_db::query_as(sql!(
154                "SELECT id, vector, payload FROM vector_points WHERE collection = ?"
155            ))
156            .bind(&collection)
157            .fetch_all(&self.pool)
158            .await
159            .map_err(|e| VectorStoreError::Search(e.to_string()))?;
160
161            let limit_usize = usize::try_from(limit).unwrap_or(usize::MAX);
162            let mut scored: Vec<ScoredVectorPoint> = rows
163                .into_iter()
164                .filter_map(|(id, blob, payload_str)| {
165                    if blob.len() % 4 != 0 {
166                        return None;
167                    }
168                    let stored: Vec<f32> = blob
169                        .chunks_exact(4)
170                        .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
171                        .collect();
172                    let payload: HashMap<String, serde_json::Value> =
173                        serde_json::from_str(&payload_str).unwrap_or_default();
174
175                    if filter
176                        .as_ref()
177                        .is_some_and(|f| !matches_filter(&payload, f))
178                    {
179                        return None;
180                    }
181
182                    let score = cosine_similarity(&vector, &stored);
183                    Some(ScoredVectorPoint { id, score, payload })
184                })
185                .collect();
186
187            scored.sort_by(|a, b| {
188                b.score
189                    .partial_cmp(&a.score)
190                    .unwrap_or(std::cmp::Ordering::Equal)
191            });
192            scored.truncate(limit_usize);
193            Ok(scored)
194        })
195    }
196
197    fn delete_by_ids(
198        &self,
199        collection: &str,
200        ids: Vec<String>,
201    ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
202        let collection = collection.to_owned();
203        Box::pin(async move {
204            for id in ids {
205                zeph_db::query(sql!(
206                    "DELETE FROM vector_points WHERE collection = ? AND id = ?"
207                ))
208                .bind(&collection)
209                .bind(&id)
210                .execute(&self.pool)
211                .await
212                .map_err(|e| VectorStoreError::Delete(e.to_string()))?;
213            }
214            Ok(())
215        })
216    }
217
218    fn scroll_all(
219        &self,
220        collection: &str,
221        key_field: &str,
222    ) -> BoxFuture<'_, Result<ScrollResult, VectorStoreError>> {
223        let collection = collection.to_owned();
224        let key_field = key_field.to_owned();
225        Box::pin(async move {
226            let rows: Vec<(String, String)> = zeph_db::query_as(sql!(
227                "SELECT id, payload FROM vector_points WHERE collection = ?"
228            ))
229            .bind(&collection)
230            .fetch_all(&self.pool)
231            .await
232            .map_err(|e| VectorStoreError::Scroll(e.to_string()))?;
233
234            let mut result = ScrollResult::new();
235            for (id, payload_str) in rows {
236                let payload: HashMap<String, serde_json::Value> =
237                    serde_json::from_str(&payload_str).unwrap_or_default();
238                if let Some(val) = payload.get(&key_field) {
239                    let mut map = HashMap::new();
240                    map.insert(
241                        key_field.clone(),
242                        val.as_str().unwrap_or_default().to_owned(),
243                    );
244                    result.insert(id, map);
245                }
246            }
247            Ok(result)
248        })
249    }
250
251    fn health_check(&self) -> BoxFuture<'_, Result<bool, VectorStoreError>> {
252        Box::pin(async move {
253            zeph_db::query_scalar::<_, i32>(sql!("SELECT 1"))
254                .fetch_one(&self.pool)
255                .await
256                .map(|_| true)
257                .map_err(|e| VectorStoreError::Collection(e.to_string()))
258        })
259    }
260}
261
#[cfg(test)]
mod tests {
    use super::*;
    use crate::store::DbStore;
    use crate::vector_store::FieldCondition;

    /// Unit vector along the first axis; the default stored and query vector.
    const AXIS_X: [f32; 4] = [1.0, 0.0, 0.0, 0.0];
    /// Unit vector along the second axis, orthogonal to [`AXIS_X`].
    const AXIS_Y: [f32; 4] = [0.0, 1.0, 0.0, 0.0];

    /// Opens an in-memory database and wraps its pool in a vector store.
    async fn setup() -> (DbVectorStore, DbStore) {
        let store = DbStore::new(":memory:").await.unwrap();
        let pool = store.pool().clone();
        let vs = DbVectorStore::new(pool);
        (vs, store)
    }

    /// Like [`setup`], but also registers `collection` and upserts `points` into it.
    async fn seeded(collection: &str, points: Vec<VectorPoint>) -> (DbVectorStore, DbStore) {
        let (vs, store) = setup().await;
        vs.ensure_collection(collection, 4).await.unwrap();
        vs.upsert(collection, points).await.unwrap();
        (vs, store)
    }

    /// Builds a point from an id, a 4-component vector and payload entries.
    fn point(id: &str, vector: [f32; 4], fields: &[(&str, serde_json::Value)]) -> VectorPoint {
        VectorPoint {
            id: id.into(),
            vector: vector.to_vec(),
            payload: fields
                .iter()
                .map(|(key, value)| ((*key).to_owned(), value.clone()))
                .collect(),
        }
    }

    /// Condition "payload field `field` equals the string `value`".
    fn text_is(field: &str, value: &str) -> FieldCondition {
        FieldCondition {
            field: field.into(),
            value: FieldValue::Text(value.into()),
        }
    }

    /// Condition "payload field `field` equals the integer `value`".
    fn int_is(field: &str, value: i64) -> FieldCondition {
        FieldCondition {
            field: field.into(),
            value: FieldValue::Integer(value),
        }
    }

    /// Searches collection `"c"` with [`AXIS_X`] as the query and a limit of 10.
    async fn search_x(vs: &DbVectorStore, filter: Option<VectorFilter>) -> Vec<ScoredVectorPoint> {
        vs.search("c", AXIS_X.to_vec(), 10, filter).await.unwrap()
    }

    /// Ids of `results`, in result order.
    fn ids(results: &[ScoredVectorPoint]) -> Vec<&str> {
        results.iter().map(|r| r.id.as_str()).collect()
    }

    #[tokio::test]
    async fn ensure_and_exists() {
        let (vs, _) = setup().await;
        assert!(!vs.collection_exists("col1").await.unwrap());
        vs.ensure_collection("col1", 4).await.unwrap();
        assert!(vs.collection_exists("col1").await.unwrap());
        // idempotent
        vs.ensure_collection("col1", 4).await.unwrap();
        assert!(vs.collection_exists("col1").await.unwrap());
    }

    #[tokio::test]
    async fn delete_collection() {
        let (vs, _) = seeded("col1", vec![point("p1", AXIS_X, &[])]).await;
        vs.delete_collection("col1").await.unwrap();
        assert!(!vs.collection_exists("col1").await.unwrap());
        // The points must be removed together with the collection row.
        let leftovers = vs
            .search("col1", AXIS_X.to_vec(), 10, None)
            .await
            .unwrap();
        assert!(leftovers.is_empty());
    }

    #[tokio::test]
    async fn upsert_and_search() {
        let (vs, _) = seeded(
            "c",
            vec![
                point("a", AXIS_X, &[("role", serde_json::json!("user"))]),
                point("b", AXIS_Y, &[("role", serde_json::json!("assistant"))]),
            ],
        )
        .await;

        let results = vs.search("c", AXIS_X.to_vec(), 2, None).await.unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "a");
        assert!((results[0].score - 1.0).abs() < 1e-5);
        // The orthogonal vector ranks below the identical one.
        assert_eq!(results[1].id, "b");
    }

    #[tokio::test]
    async fn search_with_filter() {
        let (vs, _) = seeded(
            "c",
            vec![
                point("a", AXIS_X, &[("role", serde_json::json!("user"))]),
                point("b", AXIS_X, &[("role", serde_json::json!("assistant"))]),
            ],
        )
        .await;

        let filter = VectorFilter {
            must: vec![text_is("role", "user")],
            must_not: vec![],
        };
        let results = search_x(&vs, Some(filter)).await;
        assert_eq!(ids(&results), ["a"]);
    }

    #[tokio::test]
    async fn delete_by_ids() {
        let (vs, _) = seeded(
            "c",
            vec![point("a", AXIS_X, &[]), point("b", AXIS_Y, &[])],
        )
        .await;

        vs.delete_by_ids("c", vec!["a".into()]).await.unwrap();
        let results = search_x(&vs, None).await;
        assert_eq!(ids(&results), ["b"]);
    }

    #[tokio::test]
    async fn scroll_all() {
        let (vs, _) = seeded(
            "c",
            vec![point("p1", AXIS_X, &[("text", serde_json::json!("hello"))])],
        )
        .await;

        let result = vs.scroll_all("c", "text").await.unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result["p1"]["text"], "hello");
    }

    #[tokio::test]
    async fn upsert_updates_existing() {
        let (vs, _) = seeded(
            "c",
            vec![point("p1", AXIS_X, &[("v", serde_json::json!(1))])],
        )
        .await;
        vs.upsert(
            "c",
            vec![point("p1", AXIS_Y, &[("v", serde_json::json!(2))])],
        )
        .await
        .unwrap();

        // A limit above 1 proves the second upsert replaced the row instead of adding one.
        let results = vs.search("c", AXIS_Y.to_vec(), 10, None).await.unwrap();
        assert_eq!(results.len(), 1);
        assert!((results[0].score - 1.0).abs() < 1e-5);
        assert_eq!(
            results[0].payload.get("v").and_then(serde_json::Value::as_i64),
            Some(2),
            "payload must be overwritten along with the vector"
        );
    }

    #[test]
    fn cosine_similarity_import_wired() {
        // Smoke test: verifies the re-export binding is intact. Edge-case coverage is in math.rs.
        assert!(!cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]).is_nan());
    }

    #[tokio::test]
    async fn search_with_must_not_filter() {
        let (vs, _) = seeded(
            "c",
            vec![
                point("a", AXIS_X, &[("role", serde_json::json!("user"))]),
                point("b", AXIS_X, &[("role", serde_json::json!("system"))]),
            ],
        )
        .await;

        let filter = VectorFilter {
            must: vec![],
            must_not: vec![text_is("role", "system")],
        };
        let results = search_x(&vs, Some(filter)).await;
        assert_eq!(ids(&results), ["a"]);
    }

    #[tokio::test]
    async fn search_with_integer_filter() {
        let (vs, _) = seeded(
            "c",
            vec![
                point("a", AXIS_X, &[("conv_id", serde_json::json!(1))]),
                point("b", AXIS_X, &[("conv_id", serde_json::json!(2))]),
            ],
        )
        .await;

        let filter = VectorFilter {
            must: vec![int_is("conv_id", 1)],
            must_not: vec![],
        };
        let results = search_x(&vs, Some(filter)).await;
        assert_eq!(ids(&results), ["a"]);
    }

    #[tokio::test]
    async fn search_empty_collection() {
        let (vs, _) = seeded("c", vec![]).await;
        let results = search_x(&vs, None).await;
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn search_with_must_not_integer_filter() {
        let (vs, _) = seeded(
            "c",
            vec![
                point("a", AXIS_X, &[("conv_id", serde_json::json!(1))]),
                point("b", AXIS_X, &[("conv_id", serde_json::json!(2))]),
            ],
        )
        .await;

        let filter = VectorFilter {
            must: vec![],
            must_not: vec![int_is("conv_id", 1)],
        };
        let results = search_x(&vs, Some(filter)).await;
        assert_eq!(ids(&results), ["b"]);
    }

    #[tokio::test]
    async fn search_with_combined_must_and_must_not() {
        let (vs, _) = seeded(
            "c",
            vec![
                point(
                    "a",
                    AXIS_X,
                    &[
                        ("role", serde_json::json!("user")),
                        ("conv_id", serde_json::json!(1)),
                    ],
                ),
                point(
                    "b",
                    AXIS_X,
                    &[
                        ("role", serde_json::json!("user")),
                        ("conv_id", serde_json::json!(2)),
                    ],
                ),
                point(
                    "c",
                    AXIS_X,
                    &[
                        ("role", serde_json::json!("assistant")),
                        ("conv_id", serde_json::json!(1)),
                    ],
                ),
            ],
        )
        .await;

        let filter = VectorFilter {
            must: vec![text_is("role", "user")],
            must_not: vec![int_is("conv_id", 2)],
        };
        let results = search_x(&vs, Some(filter)).await;
        // Only "a": role=user AND conv_id != 2
        assert_eq!(ids(&results), ["a"]);
    }

    #[tokio::test]
    async fn scroll_all_missing_key_field() {
        let (vs, _) = seeded(
            "c",
            vec![point("p1", AXIS_X, &[("other", serde_json::json!("value"))])],
        )
        .await;

        // key_field "text" doesn't exist in payload → point excluded from result
        let result = vs.scroll_all("c", "text").await.unwrap();
        assert!(
            result.is_empty(),
            "points without the key field must not appear in scroll result"
        );
    }

    #[tokio::test]
    async fn delete_by_ids_empty_and_nonexistent() {
        let (vs, _) = seeded("c", vec![point("a", AXIS_X, &[])]).await;

        // Empty list: no-op, must succeed
        vs.delete_by_ids("c", vec![]).await.unwrap();

        // Non-existent id: must succeed (idempotent)
        vs.delete_by_ids("c", vec!["nonexistent".into()])
            .await
            .unwrap();

        // Original point still present
        let results = search_x(&vs, None).await;
        assert_eq!(ids(&results), ["a"]);
    }

    #[tokio::test]
    async fn search_corrupt_blob_skipped() {
        // Start from a collection holding one valid point.
        let (vs, store) = seeded("c", vec![point("valid", AXIS_X, &[])]).await;

        // Insert raw invalid bytes directly into vector_points table
        // 3 bytes cannot be cast to f32 (needs multiples of 4)
        let corrupt_blob: Vec<u8> = vec![0xFF, 0xFE, 0xFD];
        let payload_json = r"{}";
        zeph_db::query(sql!(
            "INSERT INTO vector_points (id, collection, vector, payload) VALUES (?, ?, ?, ?)"
        ))
        .bind("corrupt")
        .bind("c")
        .bind(&corrupt_blob)
        .bind(payload_json)
        .execute(store.pool())
        .await
        .unwrap();

        // Search must not panic and must skip the corrupt point
        let results = search_x(&vs, None).await;
        assert_eq!(ids(&results), ["valid"]);
    }
}