Skip to main content

zeph_memory/
sqlite_vector_store.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::collections::HashMap;
5
6use sqlx::SqlitePool;
7
8use crate::vector_store::{
9    BoxFuture, FieldValue, ScoredVectorPoint, ScrollResult, VectorFilter, VectorPoint, VectorStore,
10    VectorStoreError,
11};
12
13pub struct SqliteVectorStore {
14    pool: SqlitePool,
15}
16
17impl SqliteVectorStore {
18    #[must_use]
19    pub fn new(pool: SqlitePool) -> Self {
20        Self { pool }
21    }
22}
23
24use crate::math::cosine_similarity;
25
26fn matches_filter(payload: &HashMap<String, serde_json::Value>, filter: &VectorFilter) -> bool {
27    for cond in &filter.must {
28        let Some(val) = payload.get(&cond.field) else {
29            return false;
30        };
31        let matches = match &cond.value {
32            FieldValue::Integer(i) => val.as_i64().is_some_and(|v| v == *i),
33            FieldValue::Text(t) => val.as_str().is_some_and(|v| v == t.as_str()),
34        };
35        if !matches {
36            return false;
37        }
38    }
39    for cond in &filter.must_not {
40        let Some(val) = payload.get(&cond.field) else {
41            continue;
42        };
43        let matches = match &cond.value {
44            FieldValue::Integer(i) => val.as_i64().is_some_and(|v| v == *i),
45            FieldValue::Text(t) => val.as_str().is_some_and(|v| v == t.as_str()),
46        };
47        if matches {
48            return false;
49        }
50    }
51    true
52}
53
54impl VectorStore for SqliteVectorStore {
55    fn ensure_collection(
56        &self,
57        collection: &str,
58        _vector_size: u64,
59    ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
60        let collection = collection.to_owned();
61        Box::pin(async move {
62            sqlx::query("INSERT OR IGNORE INTO vector_collections (name) VALUES (?)")
63                .bind(&collection)
64                .execute(&self.pool)
65                .await
66                .map_err(|e| VectorStoreError::Collection(e.to_string()))?;
67            Ok(())
68        })
69    }
70
71    fn collection_exists(&self, collection: &str) -> BoxFuture<'_, Result<bool, VectorStoreError>> {
72        let collection = collection.to_owned();
73        Box::pin(async move {
74            let row: (i64,) =
75                sqlx::query_as("SELECT COUNT(*) FROM vector_collections WHERE name = ?")
76                    .bind(&collection)
77                    .fetch_one(&self.pool)
78                    .await
79                    .map_err(|e| VectorStoreError::Connection(e.to_string()))?;
80            Ok(row.0 > 0)
81        })
82    }
83
84    fn delete_collection(&self, collection: &str) -> BoxFuture<'_, Result<(), VectorStoreError>> {
85        let collection = collection.to_owned();
86        Box::pin(async move {
87            sqlx::query("DELETE FROM vector_points WHERE collection = ?")
88                .bind(&collection)
89                .execute(&self.pool)
90                .await
91                .map_err(|e| VectorStoreError::Delete(e.to_string()))?;
92            sqlx::query("DELETE FROM vector_collections WHERE name = ?")
93                .bind(&collection)
94                .execute(&self.pool)
95                .await
96                .map_err(|e| VectorStoreError::Delete(e.to_string()))?;
97            Ok(())
98        })
99    }
100
101    fn upsert(
102        &self,
103        collection: &str,
104        points: Vec<VectorPoint>,
105    ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
106        let collection = collection.to_owned();
107        Box::pin(async move {
108            for point in points {
109                let vector_bytes: Vec<u8> = bytemuck::cast_slice(&point.vector).to_vec();
110                let payload_json = serde_json::to_string(&point.payload)
111                    .map_err(|e| VectorStoreError::Serialization(e.to_string()))?;
112                sqlx::query(
113                    "INSERT INTO vector_points (id, collection, vector, payload) VALUES (?, ?, ?, ?) \
114                     ON CONFLICT(collection, id) DO UPDATE SET vector = excluded.vector, payload = excluded.payload",
115                )
116                .bind(&point.id)
117                .bind(&collection)
118                .bind(&vector_bytes)
119                .bind(&payload_json)
120                .execute(&self.pool)
121                .await
122                .map_err(|e| VectorStoreError::Upsert(e.to_string()))?;
123            }
124            Ok(())
125        })
126    }
127
128    fn search(
129        &self,
130        collection: &str,
131        vector: Vec<f32>,
132        limit: u64,
133        filter: Option<VectorFilter>,
134    ) -> BoxFuture<'_, Result<Vec<ScoredVectorPoint>, VectorStoreError>> {
135        let collection = collection.to_owned();
136        Box::pin(async move {
137            let rows: Vec<(String, Vec<u8>, String)> = sqlx::query_as(
138                "SELECT id, vector, payload FROM vector_points WHERE collection = ?",
139            )
140            .bind(&collection)
141            .fetch_all(&self.pool)
142            .await
143            .map_err(|e| VectorStoreError::Search(e.to_string()))?;
144
145            let limit_usize = usize::try_from(limit).unwrap_or(usize::MAX);
146            let mut scored: Vec<ScoredVectorPoint> = rows
147                .into_iter()
148                .filter_map(|(id, blob, payload_str)| {
149                    let Ok(stored) = bytemuck::try_cast_slice::<u8, f32>(&blob) else {
150                        return None;
151                    };
152                    let payload: HashMap<String, serde_json::Value> =
153                        serde_json::from_str(&payload_str).unwrap_or_default();
154
155                    if filter
156                        .as_ref()
157                        .is_some_and(|f| !matches_filter(&payload, f))
158                    {
159                        return None;
160                    }
161
162                    let score = cosine_similarity(&vector, stored);
163                    Some(ScoredVectorPoint { id, score, payload })
164                })
165                .collect();
166
167            scored.sort_by(|a, b| {
168                b.score
169                    .partial_cmp(&a.score)
170                    .unwrap_or(std::cmp::Ordering::Equal)
171            });
172            scored.truncate(limit_usize);
173            Ok(scored)
174        })
175    }
176
177    fn delete_by_ids(
178        &self,
179        collection: &str,
180        ids: Vec<String>,
181    ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
182        let collection = collection.to_owned();
183        Box::pin(async move {
184            for id in ids {
185                sqlx::query("DELETE FROM vector_points WHERE collection = ? AND id = ?")
186                    .bind(&collection)
187                    .bind(&id)
188                    .execute(&self.pool)
189                    .await
190                    .map_err(|e| VectorStoreError::Delete(e.to_string()))?;
191            }
192            Ok(())
193        })
194    }
195
196    fn scroll_all(
197        &self,
198        collection: &str,
199        key_field: &str,
200    ) -> BoxFuture<'_, Result<ScrollResult, VectorStoreError>> {
201        let collection = collection.to_owned();
202        let key_field = key_field.to_owned();
203        Box::pin(async move {
204            let rows: Vec<(String, String)> =
205                sqlx::query_as("SELECT id, payload FROM vector_points WHERE collection = ?")
206                    .bind(&collection)
207                    .fetch_all(&self.pool)
208                    .await
209                    .map_err(|e| VectorStoreError::Scroll(e.to_string()))?;
210
211            let mut result = ScrollResult::new();
212            for (id, payload_str) in rows {
213                let payload: HashMap<String, serde_json::Value> =
214                    serde_json::from_str(&payload_str).unwrap_or_default();
215                if let Some(val) = payload.get(&key_field) {
216                    let mut map = HashMap::new();
217                    map.insert(
218                        key_field.clone(),
219                        val.as_str().unwrap_or_default().to_owned(),
220                    );
221                    result.insert(id, map);
222                }
223            }
224            Ok(result)
225        })
226    }
227
228    fn health_check(&self) -> BoxFuture<'_, Result<bool, VectorStoreError>> {
229        Box::pin(async move {
230            sqlx::query_scalar::<_, i32>("SELECT 1")
231                .fetch_one(&self.pool)
232                .await
233                .map(|_| true)
234                .map_err(|e| VectorStoreError::Collection(e.to_string()))
235        })
236    }
237}
238
#[cfg(test)]
mod tests {
    //! Integration-style tests that run the store against an in-memory
    //! SQLite database created through `SqliteStore`.

    use super::*;
    use crate::sqlite::SqliteStore;
    use crate::vector_store::FieldCondition;

    /// Unit vector along the first axis; the query used by most tests.
    const X_AXIS: [f32; 4] = [1.0, 0.0, 0.0, 0.0];
    /// Unit vector along the second axis (orthogonal to [`X_AXIS`]).
    const Y_AXIS: [f32; 4] = [0.0, 1.0, 0.0, 0.0];

    /// Creates a vector store on a fresh in-memory database. The backing
    /// `SqliteStore` is returned too for tests that need raw pool access.
    async fn setup() -> (SqliteVectorStore, SqliteStore) {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let pool = store.pool().clone();
        let vs = SqliteVectorStore::new(pool);
        (vs, store)
    }

    /// Creates a store, registers `collection` (size 4) and upserts `points`.
    async fn seeded(collection: &str, points: Vec<VectorPoint>) -> SqliteVectorStore {
        let (vs, _) = setup().await;
        vs.ensure_collection(collection, 4).await.unwrap();
        vs.upsert(collection, points).await.unwrap();
        vs
    }

    /// Builds a point from an id, a 4-dimensional vector and payload entries.
    fn point<const N: usize>(
        id: &str,
        vector: [f32; 4],
        payload: [(&str, serde_json::Value); N],
    ) -> VectorPoint {
        VectorPoint {
            id: id.into(),
            vector: vector.to_vec(),
            payload: payload.into_iter().map(|(k, v)| (k.into(), v)).collect(),
        }
    }

    /// Condition comparing `field` against a text value.
    fn text_cond(field: &str, value: &str) -> FieldCondition {
        FieldCondition {
            field: field.into(),
            value: FieldValue::Text(value.into()),
        }
    }

    /// Condition comparing `field` against an integer value.
    fn int_cond(field: &str, value: i64) -> FieldCondition {
        FieldCondition {
            field: field.into(),
            value: FieldValue::Integer(value),
        }
    }

    /// Searches collection "c" with the [`X_AXIS`] query and a limit of 10.
    async fn search_x(
        vs: &SqliteVectorStore,
        filter: Option<VectorFilter>,
    ) -> Vec<ScoredVectorPoint> {
        vs.search("c", X_AXIS.to_vec(), 10, filter).await.unwrap()
    }

    /// Asserts that `results` holds exactly one point with id `expected_id`.
    fn assert_single_hit(results: &[ScoredVectorPoint], expected_id: &str) {
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, expected_id);
    }

    #[tokio::test]
    async fn ensure_and_exists() {
        let (vs, _) = setup().await;
        assert!(!vs.collection_exists("col1").await.unwrap());
        vs.ensure_collection("col1", 4).await.unwrap();
        assert!(vs.collection_exists("col1").await.unwrap());
        // A second call for the same name must succeed and change nothing.
        vs.ensure_collection("col1", 4).await.unwrap();
        assert!(vs.collection_exists("col1").await.unwrap());
    }

    #[tokio::test]
    async fn delete_collection() {
        let vs = seeded("col1", vec![point("p1", X_AXIS, [])]).await;
        vs.delete_collection("col1").await.unwrap();
        assert!(!vs.collection_exists("col1").await.unwrap());
    }

    #[tokio::test]
    async fn upsert_and_search() {
        let vs = seeded(
            "c",
            vec![
                point("a", X_AXIS, [("role", serde_json::json!("user"))]),
                point("b", Y_AXIS, [("role", serde_json::json!("assistant"))]),
            ],
        )
        .await;

        let results = vs.search("c", X_AXIS.to_vec(), 2, None).await.unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "a");
        assert!((results[0].score - 1.0).abs() < 1e-5);
    }

    #[tokio::test]
    async fn search_with_filter() {
        let vs = seeded(
            "c",
            vec![
                point("a", X_AXIS, [("role", serde_json::json!("user"))]),
                point("b", X_AXIS, [("role", serde_json::json!("assistant"))]),
            ],
        )
        .await;

        let filter = VectorFilter {
            must: vec![text_cond("role", "user")],
            must_not: vec![],
        };
        assert_single_hit(&search_x(&vs, Some(filter)).await, "a");
    }

    #[tokio::test]
    async fn delete_by_ids() {
        let vs = seeded("c", vec![point("a", X_AXIS, []), point("b", Y_AXIS, [])]).await;
        vs.delete_by_ids("c", vec!["a".into()]).await.unwrap();
        assert_single_hit(&search_x(&vs, None).await, "b");
    }

    #[tokio::test]
    async fn scroll_all() {
        let vs = seeded(
            "c",
            vec![point("p1", X_AXIS, [("text", serde_json::json!("hello"))])],
        )
        .await;
        let result = vs.scroll_all("c", "text").await.unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result["p1"]["text"], "hello");
    }

    #[tokio::test]
    async fn upsert_updates_existing() {
        let vs = seeded(
            "c",
            vec![point("p1", X_AXIS, [("v", serde_json::json!(1))])],
        )
        .await;
        // Same id again: the stored vector must be replaced, not duplicated.
        vs.upsert(
            "c",
            vec![point("p1", Y_AXIS, [("v", serde_json::json!(2))])],
        )
        .await
        .unwrap();

        let results = vs.search("c", Y_AXIS.to_vec(), 1, None).await.unwrap();
        assert_eq!(results.len(), 1);
        assert!((results[0].score - 1.0).abs() < 1e-5);
    }

    #[test]
    fn cosine_similarity_import_wired() {
        // Smoke test: verifies the re-export binding is intact. Edge-case coverage is in math.rs.
        assert!(!cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]).is_nan());
    }

    #[tokio::test]
    async fn search_with_must_not_filter() {
        let vs = seeded(
            "c",
            vec![
                point("a", X_AXIS, [("role", serde_json::json!("user"))]),
                point("b", X_AXIS, [("role", serde_json::json!("system"))]),
            ],
        )
        .await;

        let filter = VectorFilter {
            must: vec![],
            must_not: vec![text_cond("role", "system")],
        };
        assert_single_hit(&search_x(&vs, Some(filter)).await, "a");
    }

    #[tokio::test]
    async fn search_with_integer_filter() {
        let vs = seeded(
            "c",
            vec![
                point("a", X_AXIS, [("conv_id", serde_json::json!(1))]),
                point("b", X_AXIS, [("conv_id", serde_json::json!(2))]),
            ],
        )
        .await;

        let filter = VectorFilter {
            must: vec![int_cond("conv_id", 1)],
            must_not: vec![],
        };
        assert_single_hit(&search_x(&vs, Some(filter)).await, "a");
    }

    #[tokio::test]
    async fn search_empty_collection() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        let results = search_x(&vs, None).await;
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn search_with_must_not_integer_filter() {
        let vs = seeded(
            "c",
            vec![
                point("a", X_AXIS, [("conv_id", serde_json::json!(1))]),
                point("b", X_AXIS, [("conv_id", serde_json::json!(2))]),
            ],
        )
        .await;

        let filter = VectorFilter {
            must: vec![],
            must_not: vec![int_cond("conv_id", 1)],
        };
        assert_single_hit(&search_x(&vs, Some(filter)).await, "b");
    }

    #[tokio::test]
    async fn search_with_combined_must_and_must_not() {
        let vs = seeded(
            "c",
            vec![
                point(
                    "a",
                    X_AXIS,
                    [
                        ("role", serde_json::json!("user")),
                        ("conv_id", serde_json::json!(1)),
                    ],
                ),
                point(
                    "b",
                    X_AXIS,
                    [
                        ("role", serde_json::json!("user")),
                        ("conv_id", serde_json::json!(2)),
                    ],
                ),
                point(
                    "c",
                    X_AXIS,
                    [
                        ("role", serde_json::json!("assistant")),
                        ("conv_id", serde_json::json!(1)),
                    ],
                ),
            ],
        )
        .await;

        let filter = VectorFilter {
            must: vec![text_cond("role", "user")],
            must_not: vec![int_cond("conv_id", 2)],
        };
        // Only "a": role=user AND conv_id != 2
        assert_single_hit(&search_x(&vs, Some(filter)).await, "a");
    }

    #[tokio::test]
    async fn scroll_all_missing_key_field() {
        let vs = seeded(
            "c",
            vec![point("p1", X_AXIS, [("other", serde_json::json!("value"))])],
        )
        .await;
        // key_field "text" doesn't exist in payload → point excluded from result
        let result = vs.scroll_all("c", "text").await.unwrap();
        assert!(
            result.is_empty(),
            "points without the key field must not appear in scroll result"
        );
    }

    #[tokio::test]
    async fn delete_by_ids_empty_and_nonexistent() {
        let vs = seeded("c", vec![point("a", X_AXIS, [])]).await;

        // Empty list: no-op, must succeed
        vs.delete_by_ids("c", vec![]).await.unwrap();

        // Non-existent id: must succeed (idempotent)
        vs.delete_by_ids("c", vec!["nonexistent".into()])
            .await
            .unwrap();

        // Original point still present
        assert_single_hit(&search_x(&vs, None).await, "a");
    }

    #[tokio::test]
    async fn search_corrupt_blob_skipped() {
        let (vs, store) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();

        // Insert a valid point first
        vs.upsert("c", vec![point("valid", X_AXIS, [])])
            .await
            .unwrap();

        // Insert raw invalid bytes directly into vector_points table
        // 3 bytes cannot be cast to f32 (needs multiples of 4)
        let corrupt_blob: Vec<u8> = vec![0xFF, 0xFE, 0xFD];
        let payload_json = r"{}";
        sqlx::query(
            "INSERT INTO vector_points (id, collection, vector, payload) VALUES (?, ?, ?, ?)",
        )
        .bind("corrupt")
        .bind("c")
        .bind(&corrupt_blob)
        .bind(payload_json)
        .execute(store.pool())
        .await
        .unwrap();

        // Search must not panic and must skip the corrupt point
        assert_single_hit(&search_x(&vs, None).await, "valid");
    }
}