Skip to main content

zeph_memory/
sqlite_vector_store.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::collections::HashMap;
5use std::future::Future;
6use std::pin::Pin;
7
8use sqlx::SqlitePool;
9
10use crate::vector_store::{
11    FieldValue, ScoredVectorPoint, ScrollResult, VectorFilter, VectorPoint, VectorStore,
12    VectorStoreError,
13};
14
15pub struct SqliteVectorStore {
16    pool: SqlitePool,
17}
18
19impl SqliteVectorStore {
20    #[must_use]
21    pub fn new(pool: SqlitePool) -> Self {
22        Self { pool }
23    }
24}
25
26use crate::math::cosine_similarity;
27
28fn matches_filter(payload: &HashMap<String, serde_json::Value>, filter: &VectorFilter) -> bool {
29    for cond in &filter.must {
30        let Some(val) = payload.get(&cond.field) else {
31            return false;
32        };
33        let matches = match &cond.value {
34            FieldValue::Integer(i) => val.as_i64().is_some_and(|v| v == *i),
35            FieldValue::Text(t) => val.as_str().is_some_and(|v| v == t.as_str()),
36        };
37        if !matches {
38            return false;
39        }
40    }
41    for cond in &filter.must_not {
42        let Some(val) = payload.get(&cond.field) else {
43            continue;
44        };
45        let matches = match &cond.value {
46            FieldValue::Integer(i) => val.as_i64().is_some_and(|v| v == *i),
47            FieldValue::Text(t) => val.as_str().is_some_and(|v| v == t.as_str()),
48        };
49        if matches {
50            return false;
51        }
52    }
53    true
54}
55
/// Heap-allocated, pinned, `Send` future that may borrow for `'fut`; every `VectorStore`
/// method in this file returns one of these.
type BoxFuture<'fut, Out> = Pin<Box<dyn Future<Output = Out> + Send + 'fut>>;
57
58impl VectorStore for SqliteVectorStore {
59    fn ensure_collection(
60        &self,
61        collection: &str,
62        _vector_size: u64,
63    ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
64        let collection = collection.to_owned();
65        Box::pin(async move {
66            sqlx::query("INSERT OR IGNORE INTO vector_collections (name) VALUES (?)")
67                .bind(&collection)
68                .execute(&self.pool)
69                .await
70                .map_err(|e| VectorStoreError::Collection(e.to_string()))?;
71            Ok(())
72        })
73    }
74
75    fn collection_exists(&self, collection: &str) -> BoxFuture<'_, Result<bool, VectorStoreError>> {
76        let collection = collection.to_owned();
77        Box::pin(async move {
78            let row: (i64,) =
79                sqlx::query_as("SELECT COUNT(*) FROM vector_collections WHERE name = ?")
80                    .bind(&collection)
81                    .fetch_one(&self.pool)
82                    .await
83                    .map_err(|e| VectorStoreError::Connection(e.to_string()))?;
84            Ok(row.0 > 0)
85        })
86    }
87
88    fn delete_collection(&self, collection: &str) -> BoxFuture<'_, Result<(), VectorStoreError>> {
89        let collection = collection.to_owned();
90        Box::pin(async move {
91            sqlx::query("DELETE FROM vector_points WHERE collection = ?")
92                .bind(&collection)
93                .execute(&self.pool)
94                .await
95                .map_err(|e| VectorStoreError::Delete(e.to_string()))?;
96            sqlx::query("DELETE FROM vector_collections WHERE name = ?")
97                .bind(&collection)
98                .execute(&self.pool)
99                .await
100                .map_err(|e| VectorStoreError::Delete(e.to_string()))?;
101            Ok(())
102        })
103    }
104
105    fn upsert(
106        &self,
107        collection: &str,
108        points: Vec<VectorPoint>,
109    ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
110        let collection = collection.to_owned();
111        Box::pin(async move {
112            for point in points {
113                let vector_bytes: Vec<u8> = bytemuck::cast_slice(&point.vector).to_vec();
114                let payload_json = serde_json::to_string(&point.payload)
115                    .map_err(|e| VectorStoreError::Serialization(e.to_string()))?;
116                sqlx::query(
117                    "INSERT INTO vector_points (id, collection, vector, payload) VALUES (?, ?, ?, ?) \
118                     ON CONFLICT(collection, id) DO UPDATE SET vector = excluded.vector, payload = excluded.payload",
119                )
120                .bind(&point.id)
121                .bind(&collection)
122                .bind(&vector_bytes)
123                .bind(&payload_json)
124                .execute(&self.pool)
125                .await
126                .map_err(|e| VectorStoreError::Upsert(e.to_string()))?;
127            }
128            Ok(())
129        })
130    }
131
132    fn search(
133        &self,
134        collection: &str,
135        vector: Vec<f32>,
136        limit: u64,
137        filter: Option<VectorFilter>,
138    ) -> BoxFuture<'_, Result<Vec<ScoredVectorPoint>, VectorStoreError>> {
139        let collection = collection.to_owned();
140        Box::pin(async move {
141            let rows: Vec<(String, Vec<u8>, String)> = sqlx::query_as(
142                "SELECT id, vector, payload FROM vector_points WHERE collection = ?",
143            )
144            .bind(&collection)
145            .fetch_all(&self.pool)
146            .await
147            .map_err(|e| VectorStoreError::Search(e.to_string()))?;
148
149            let limit_usize = usize::try_from(limit).unwrap_or(usize::MAX);
150            let mut scored: Vec<ScoredVectorPoint> = rows
151                .into_iter()
152                .filter_map(|(id, blob, payload_str)| {
153                    let Ok(stored) = bytemuck::try_cast_slice::<u8, f32>(&blob) else {
154                        return None;
155                    };
156                    let payload: HashMap<String, serde_json::Value> =
157                        serde_json::from_str(&payload_str).unwrap_or_default();
158
159                    if filter
160                        .as_ref()
161                        .is_some_and(|f| !matches_filter(&payload, f))
162                    {
163                        return None;
164                    }
165
166                    let score = cosine_similarity(&vector, stored);
167                    Some(ScoredVectorPoint { id, score, payload })
168                })
169                .collect();
170
171            scored.sort_by(|a, b| {
172                b.score
173                    .partial_cmp(&a.score)
174                    .unwrap_or(std::cmp::Ordering::Equal)
175            });
176            scored.truncate(limit_usize);
177            Ok(scored)
178        })
179    }
180
181    fn delete_by_ids(
182        &self,
183        collection: &str,
184        ids: Vec<String>,
185    ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
186        let collection = collection.to_owned();
187        Box::pin(async move {
188            for id in ids {
189                sqlx::query("DELETE FROM vector_points WHERE collection = ? AND id = ?")
190                    .bind(&collection)
191                    .bind(&id)
192                    .execute(&self.pool)
193                    .await
194                    .map_err(|e| VectorStoreError::Delete(e.to_string()))?;
195            }
196            Ok(())
197        })
198    }
199
200    fn scroll_all(
201        &self,
202        collection: &str,
203        key_field: &str,
204    ) -> BoxFuture<'_, Result<ScrollResult, VectorStoreError>> {
205        let collection = collection.to_owned();
206        let key_field = key_field.to_owned();
207        Box::pin(async move {
208            let rows: Vec<(String, String)> =
209                sqlx::query_as("SELECT id, payload FROM vector_points WHERE collection = ?")
210                    .bind(&collection)
211                    .fetch_all(&self.pool)
212                    .await
213                    .map_err(|e| VectorStoreError::Scroll(e.to_string()))?;
214
215            let mut result = ScrollResult::new();
216            for (id, payload_str) in rows {
217                let payload: HashMap<String, serde_json::Value> =
218                    serde_json::from_str(&payload_str).unwrap_or_default();
219                if let Some(val) = payload.get(&key_field) {
220                    let mut map = HashMap::new();
221                    map.insert(
222                        key_field.clone(),
223                        val.as_str().unwrap_or_default().to_owned(),
224                    );
225                    result.insert(id, map);
226                }
227            }
228            Ok(result)
229        })
230    }
231
232    fn health_check(&self) -> BoxFuture<'_, Result<bool, VectorStoreError>> {
233        Box::pin(async move {
234            sqlx::query_scalar::<_, i32>("SELECT 1")
235                .fetch_one(&self.pool)
236                .await
237                .map(|_| true)
238                .map_err(|e| VectorStoreError::Collection(e.to_string()))
239        })
240    }
241}
242
#[cfg(test)]
mod tests {
    // Every test runs against a fresh in-memory SQLite database created via `SqliteStore`.
    use super::*;
    use crate::sqlite::SqliteStore;
    use crate::vector_store::FieldCondition;

    /// Opens an in-memory `SqliteStore` and wraps a clone of its pool in a
    /// `SqliteVectorStore`. The store is returned too, so tests can run raw SQL on the same
    /// pool.
    async fn setup() -> (SqliteVectorStore, SqliteStore) {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let pool = store.pool().clone();
        let vs = SqliteVectorStore::new(pool);
        (vs, store)
    }

    #[tokio::test]
    async fn ensure_and_exists() {
        let (vs, _) = setup().await;
        assert!(!vs.collection_exists("col1").await.unwrap());
        vs.ensure_collection("col1", 4).await.unwrap();
        assert!(vs.collection_exists("col1").await.unwrap());
        // idempotent
        vs.ensure_collection("col1", 4).await.unwrap();
        assert!(vs.collection_exists("col1").await.unwrap());
    }

    #[tokio::test]
    async fn delete_collection() {
        let (vs, _) = setup().await;
        vs.ensure_collection("col1", 4).await.unwrap();
        vs.upsert(
            "col1",
            vec![VectorPoint {
                id: "p1".into(),
                vector: vec![1.0, 0.0, 0.0, 0.0],
                payload: HashMap::new(),
            }],
        )
        .await
        .unwrap();
        vs.delete_collection("col1").await.unwrap();
        assert!(!vs.collection_exists("col1").await.unwrap());
        // The collection's points must be gone too, not just its registration.
        let results = vs
            .search("col1", vec![1.0, 0.0, 0.0, 0.0], 10, None)
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn upsert_and_search() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        vs.upsert(
            "c",
            vec![
                VectorPoint {
                    id: "a".into(),
                    vector: vec![1.0, 0.0, 0.0, 0.0],
                    payload: HashMap::from([("role".into(), serde_json::json!("user"))]),
                },
                VectorPoint {
                    id: "b".into(),
                    vector: vec![0.0, 1.0, 0.0, 0.0],
                    payload: HashMap::from([("role".into(), serde_json::json!("assistant"))]),
                },
            ],
        )
        .await
        .unwrap();

        let results = vs
            .search("c", vec![1.0, 0.0, 0.0, 0.0], 2, None)
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "a");
        assert!((results[0].score - 1.0).abs() < 1e-5);
    }

    #[tokio::test]
    async fn search_with_filter() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        vs.upsert(
            "c",
            vec![
                VectorPoint {
                    id: "a".into(),
                    vector: vec![1.0, 0.0, 0.0, 0.0],
                    payload: HashMap::from([("role".into(), serde_json::json!("user"))]),
                },
                VectorPoint {
                    id: "b".into(),
                    vector: vec![1.0, 0.0, 0.0, 0.0],
                    payload: HashMap::from([("role".into(), serde_json::json!("assistant"))]),
                },
            ],
        )
        .await
        .unwrap();

        let filter = VectorFilter {
            must: vec![FieldCondition {
                field: "role".into(),
                value: FieldValue::Text("user".into()),
            }],
            must_not: vec![],
        };
        let results = vs
            .search("c", vec![1.0, 0.0, 0.0, 0.0], 10, Some(filter))
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "a");
    }

    #[tokio::test]
    async fn delete_by_ids() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        vs.upsert(
            "c",
            vec![
                VectorPoint {
                    id: "a".into(),
                    vector: vec![1.0, 0.0, 0.0, 0.0],
                    payload: HashMap::new(),
                },
                VectorPoint {
                    id: "b".into(),
                    vector: vec![0.0, 1.0, 0.0, 0.0],
                    payload: HashMap::new(),
                },
            ],
        )
        .await
        .unwrap();
        vs.delete_by_ids("c", vec!["a".into()]).await.unwrap();
        let results = vs
            .search("c", vec![1.0, 0.0, 0.0, 0.0], 10, None)
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "b");
    }

    #[tokio::test]
    async fn scroll_all() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        vs.upsert(
            "c",
            vec![VectorPoint {
                id: "p1".into(),
                vector: vec![1.0, 0.0, 0.0, 0.0],
                payload: HashMap::from([("text".into(), serde_json::json!("hello"))]),
            }],
        )
        .await
        .unwrap();
        let result = vs.scroll_all("c", "text").await.unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result["p1"]["text"], "hello");
    }

    #[tokio::test]
    async fn upsert_updates_existing() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        vs.upsert(
            "c",
            vec![VectorPoint {
                id: "p1".into(),
                vector: vec![1.0, 0.0, 0.0, 0.0],
                payload: HashMap::from([("v".into(), serde_json::json!(1))]),
            }],
        )
        .await
        .unwrap();
        vs.upsert(
            "c",
            vec![VectorPoint {
                id: "p1".into(),
                vector: vec![0.0, 1.0, 0.0, 0.0],
                payload: HashMap::from([("v".into(), serde_json::json!(2))]),
            }],
        )
        .await
        .unwrap();
        let results = vs
            .search("c", vec![0.0, 1.0, 0.0, 0.0], 1, None)
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert!((results[0].score - 1.0).abs() < 1e-5);
    }

    #[test]
    fn cosine_similarity_import_wired() {
        // Smoke test: verifies the re-export binding is intact. Edge-case coverage is in math.rs.
        assert!(!cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]).is_nan());
    }

    #[tokio::test]
    async fn search_with_must_not_filter() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        vs.upsert(
            "c",
            vec![
                VectorPoint {
                    id: "a".into(),
                    vector: vec![1.0, 0.0, 0.0, 0.0],
                    payload: HashMap::from([("role".into(), serde_json::json!("user"))]),
                },
                VectorPoint {
                    id: "b".into(),
                    vector: vec![1.0, 0.0, 0.0, 0.0],
                    payload: HashMap::from([("role".into(), serde_json::json!("system"))]),
                },
            ],
        )
        .await
        .unwrap();

        let filter = VectorFilter {
            must: vec![],
            must_not: vec![FieldCondition {
                field: "role".into(),
                value: FieldValue::Text("system".into()),
            }],
        };
        let results = vs
            .search("c", vec![1.0, 0.0, 0.0, 0.0], 10, Some(filter))
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "a");
    }

    #[tokio::test]
    async fn search_with_integer_filter() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        vs.upsert(
            "c",
            vec![
                VectorPoint {
                    id: "a".into(),
                    vector: vec![1.0, 0.0, 0.0, 0.0],
                    payload: HashMap::from([("conv_id".into(), serde_json::json!(1))]),
                },
                VectorPoint {
                    id: "b".into(),
                    vector: vec![1.0, 0.0, 0.0, 0.0],
                    payload: HashMap::from([("conv_id".into(), serde_json::json!(2))]),
                },
            ],
        )
        .await
        .unwrap();

        let filter = VectorFilter {
            must: vec![FieldCondition {
                field: "conv_id".into(),
                value: FieldValue::Integer(1),
            }],
            must_not: vec![],
        };
        let results = vs
            .search("c", vec![1.0, 0.0, 0.0, 0.0], 10, Some(filter))
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "a");
    }

    #[tokio::test]
    async fn search_empty_collection() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        let results = vs
            .search("c", vec![1.0, 0.0, 0.0, 0.0], 10, None)
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn search_with_must_not_integer_filter() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        vs.upsert(
            "c",
            vec![
                VectorPoint {
                    id: "a".into(),
                    vector: vec![1.0, 0.0, 0.0, 0.0],
                    payload: HashMap::from([("conv_id".into(), serde_json::json!(1))]),
                },
                VectorPoint {
                    id: "b".into(),
                    vector: vec![1.0, 0.0, 0.0, 0.0],
                    payload: HashMap::from([("conv_id".into(), serde_json::json!(2))]),
                },
            ],
        )
        .await
        .unwrap();

        let filter = VectorFilter {
            must: vec![],
            must_not: vec![FieldCondition {
                field: "conv_id".into(),
                value: FieldValue::Integer(1),
            }],
        };
        let results = vs
            .search("c", vec![1.0, 0.0, 0.0, 0.0], 10, Some(filter))
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "b");
    }

    #[tokio::test]
    async fn search_with_combined_must_and_must_not() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        vs.upsert(
            "c",
            vec![
                VectorPoint {
                    id: "a".into(),
                    vector: vec![1.0, 0.0, 0.0, 0.0],
                    payload: HashMap::from([
                        ("role".into(), serde_json::json!("user")),
                        ("conv_id".into(), serde_json::json!(1)),
                    ]),
                },
                VectorPoint {
                    id: "b".into(),
                    vector: vec![1.0, 0.0, 0.0, 0.0],
                    payload: HashMap::from([
                        ("role".into(), serde_json::json!("user")),
                        ("conv_id".into(), serde_json::json!(2)),
                    ]),
                },
                VectorPoint {
                    id: "c".into(),
                    vector: vec![1.0, 0.0, 0.0, 0.0],
                    payload: HashMap::from([
                        ("role".into(), serde_json::json!("assistant")),
                        ("conv_id".into(), serde_json::json!(1)),
                    ]),
                },
            ],
        )
        .await
        .unwrap();

        let filter = VectorFilter {
            must: vec![FieldCondition {
                field: "role".into(),
                value: FieldValue::Text("user".into()),
            }],
            must_not: vec![FieldCondition {
                field: "conv_id".into(),
                value: FieldValue::Integer(2),
            }],
        };
        let results = vs
            .search("c", vec![1.0, 0.0, 0.0, 0.0], 10, Some(filter))
            .await
            .unwrap();
        // Only "a": role=user AND conv_id != 2
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "a");
    }

    #[tokio::test]
    async fn scroll_all_missing_key_field() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        vs.upsert(
            "c",
            vec![VectorPoint {
                id: "p1".into(),
                vector: vec![1.0, 0.0, 0.0, 0.0],
                payload: HashMap::from([("other".into(), serde_json::json!("value"))]),
            }],
        )
        .await
        .unwrap();
        // key_field "text" doesn't exist in payload → point excluded from result
        let result = vs.scroll_all("c", "text").await.unwrap();
        assert!(
            result.is_empty(),
            "points without the key field must not appear in scroll result"
        );
    }

    #[tokio::test]
    async fn delete_by_ids_empty_and_nonexistent() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        vs.upsert(
            "c",
            vec![VectorPoint {
                id: "a".into(),
                vector: vec![1.0, 0.0, 0.0, 0.0],
                payload: HashMap::new(),
            }],
        )
        .await
        .unwrap();

        // Empty list: no-op, must succeed
        vs.delete_by_ids("c", vec![]).await.unwrap();

        // Non-existent id: must succeed (idempotent)
        vs.delete_by_ids("c", vec!["nonexistent".into()])
            .await
            .unwrap();

        // Original point still present
        let results = vs
            .search("c", vec![1.0, 0.0, 0.0, 0.0], 10, None)
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "a");
    }

    #[tokio::test]
    async fn search_corrupt_blob_skipped() {
        let (vs, store) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();

        // Insert a valid point first
        vs.upsert(
            "c",
            vec![VectorPoint {
                id: "valid".into(),
                vector: vec![1.0, 0.0, 0.0, 0.0],
                payload: HashMap::new(),
            }],
        )
        .await
        .unwrap();

        // Insert raw invalid bytes directly into vector_points table
        // 3 bytes cannot be cast to f32 (needs multiples of 4)
        let corrupt_blob: Vec<u8> = vec![0xFF, 0xFE, 0xFD];
        let payload_json = r"{}";
        sqlx::query(
            "INSERT INTO vector_points (id, collection, vector, payload) VALUES (?, ?, ?, ?)",
        )
        .bind("corrupt")
        .bind("c")
        .bind(&corrupt_blob)
        .bind(payload_json)
        .execute(store.pool())
        .await
        .unwrap();

        // Search must not panic and must skip the corrupt point
        let results = vs
            .search("c", vec![1.0, 0.0, 0.0, 0.0], 10, None)
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "valid");
    }

    #[tokio::test]
    async fn health_check_reports_healthy_pool() {
        let (vs, _) = setup().await;
        assert!(vs.health_check().await.unwrap());
    }

    #[tokio::test]
    async fn upsert_empty_batch_is_noop() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        vs.upsert("c", vec![]).await.unwrap();
        let results = vs
            .search("c", vec![1.0, 0.0, 0.0, 0.0], 10, None)
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn search_orders_by_score_and_respects_limit() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        // Inserted out of score order on purpose.
        vs.upsert(
            "c",
            vec![
                VectorPoint {
                    id: "low".into(),
                    vector: vec![0.0, 1.0, 0.0, 0.0],
                    payload: HashMap::new(),
                },
                VectorPoint {
                    id: "high".into(),
                    vector: vec![1.0, 0.0, 0.0, 0.0],
                    payload: HashMap::new(),
                },
                VectorPoint {
                    id: "mid".into(),
                    vector: vec![0.8, 0.6, 0.0, 0.0],
                    payload: HashMap::new(),
                },
            ],
        )
        .await
        .unwrap();

        let results = vs
            .search("c", vec![1.0, 0.0, 0.0, 0.0], 2, None)
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "high");
        assert_eq!(results[1].id, "mid");
        assert!(results[0].score > results[1].score);

        // A zero limit yields nothing.
        let none = vs
            .search("c", vec![1.0, 0.0, 0.0, 0.0], 0, None)
            .await
            .unwrap();
        assert!(none.is_empty());
    }

    #[tokio::test]
    async fn collections_are_isolated() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c1", 4).await.unwrap();
        vs.ensure_collection("c2", 4).await.unwrap();
        for col in ["c1", "c2"] {
            vs.upsert(
                col,
                vec![VectorPoint {
                    id: "shared".into(),
                    vector: vec![1.0, 0.0, 0.0, 0.0],
                    payload: HashMap::new(),
                }],
            )
            .await
            .unwrap();
        }

        // Deleting an id in one collection must not touch the same id in another.
        vs.delete_by_ids("c1", vec!["shared".into()]).await.unwrap();

        let in_c1 = vs
            .search("c1", vec![1.0, 0.0, 0.0, 0.0], 10, None)
            .await
            .unwrap();
        let in_c2 = vs
            .search("c2", vec![1.0, 0.0, 0.0, 0.0], 10, None)
            .await
            .unwrap();
        assert!(in_c1.is_empty());
        assert_eq!(in_c2.len(), 1);
        assert_eq!(in_c2[0].id, "shared");
    }

    #[tokio::test]
    async fn filter_on_missing_field() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        vs.upsert(
            "c",
            vec![
                VectorPoint {
                    id: "tagged".into(),
                    vector: vec![1.0, 0.0, 0.0, 0.0],
                    payload: HashMap::from([("role".into(), serde_json::json!("user"))]),
                },
                VectorPoint {
                    id: "untagged".into(),
                    vector: vec![1.0, 0.0, 0.0, 0.0],
                    payload: HashMap::new(),
                },
            ],
        )
        .await
        .unwrap();

        // `must` on a field the payload lacks excludes the point...
        let must = VectorFilter {
            must: vec![FieldCondition {
                field: "role".into(),
                value: FieldValue::Text("user".into()),
            }],
            must_not: vec![],
        };
        let results = vs
            .search("c", vec![1.0, 0.0, 0.0, 0.0], 10, Some(must))
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "tagged");

        // ...while `must_not` on a field the payload lacks keeps it.
        let must_not = VectorFilter {
            must: vec![],
            must_not: vec![FieldCondition {
                field: "role".into(),
                value: FieldValue::Text("user".into()),
            }],
        };
        let results = vs
            .search("c", vec![1.0, 0.0, 0.0, 0.0], 10, Some(must_not))
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "untagged");
    }
}