Skip to main content

zeph_memory/
sqlite_vector_store.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::collections::HashMap;
5use std::future::Future;
6use std::pin::Pin;
7
8use sqlx::SqlitePool;
9
10use crate::vector_store::{
11    FieldValue, ScoredVectorPoint, ScrollResult, VectorFilter, VectorPoint, VectorStore,
12    VectorStoreError,
13};
14
15pub struct SqliteVectorStore {
16    pool: SqlitePool,
17}
18
19impl SqliteVectorStore {
20    #[must_use]
21    pub fn new(pool: SqlitePool) -> Self {
22        Self { pool }
23    }
24}
25
/// Computes the cosine similarity of `a` and `b`.
///
/// Returns a value in `[-1.0, 1.0]`. Degenerate inputs yield `0.0` instead of
/// an error or NaN:
///
/// * slices of different length, or empty slices;
/// * a vector whose norm is zero;
/// * inputs containing NaN or infinities (the quotient is not finite).
///
/// Sums are accumulated in `f64` so that components whose squares do not fit
/// in `f32` (very large or very small magnitudes) neither overflow to infinity
/// nor underflow to zero.
// The clamped quotient lies in [-1, 1], so narrowing to `f32` cannot overflow.
#[allow(clippy::cast_possible_truncation)]
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    // Single pass: dot product and both squared norms.
    let (dot, sq_a, sq_b) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, sq_a, sq_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, sq_a + x * x, sq_b + y * y)
        },
    );

    let denom = sq_a.sqrt() * sq_b.sqrt();
    if denom == 0.0 {
        return 0.0;
    }

    let cosine = dot / denom;
    if cosine.is_finite() {
        // Rounding can push the quotient marginally outside the valid range.
        cosine.clamp(-1.0, 1.0) as f32
    } else {
        0.0
    }
}
38
39fn matches_filter(payload: &HashMap<String, serde_json::Value>, filter: &VectorFilter) -> bool {
40    for cond in &filter.must {
41        let Some(val) = payload.get(&cond.field) else {
42            return false;
43        };
44        let matches = match &cond.value {
45            FieldValue::Integer(i) => val.as_i64().is_some_and(|v| v == *i),
46            FieldValue::Text(t) => val.as_str().is_some_and(|v| v == t.as_str()),
47        };
48        if !matches {
49            return false;
50        }
51    }
52    for cond in &filter.must_not {
53        let Some(val) = payload.get(&cond.field) else {
54            continue;
55        };
56        let matches = match &cond.value {
57            FieldValue::Integer(i) => val.as_i64().is_some_and(|v| v == *i),
58            FieldValue::Text(t) => val.as_str().is_some_and(|v| v == t.as_str()),
59        };
60        if matches {
61            return false;
62        }
63    }
64    true
65}
66
67type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
68
69impl VectorStore for SqliteVectorStore {
70    fn ensure_collection(
71        &self,
72        collection: &str,
73        _vector_size: u64,
74    ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
75        let collection = collection.to_owned();
76        Box::pin(async move {
77            sqlx::query("INSERT OR IGNORE INTO vector_collections (name) VALUES (?)")
78                .bind(&collection)
79                .execute(&self.pool)
80                .await
81                .map_err(|e| VectorStoreError::Collection(e.to_string()))?;
82            Ok(())
83        })
84    }
85
86    fn collection_exists(&self, collection: &str) -> BoxFuture<'_, Result<bool, VectorStoreError>> {
87        let collection = collection.to_owned();
88        Box::pin(async move {
89            let row: (i64,) =
90                sqlx::query_as("SELECT COUNT(*) FROM vector_collections WHERE name = ?")
91                    .bind(&collection)
92                    .fetch_one(&self.pool)
93                    .await
94                    .map_err(|e| VectorStoreError::Connection(e.to_string()))?;
95            Ok(row.0 > 0)
96        })
97    }
98
99    fn delete_collection(&self, collection: &str) -> BoxFuture<'_, Result<(), VectorStoreError>> {
100        let collection = collection.to_owned();
101        Box::pin(async move {
102            sqlx::query("DELETE FROM vector_points WHERE collection = ?")
103                .bind(&collection)
104                .execute(&self.pool)
105                .await
106                .map_err(|e| VectorStoreError::Delete(e.to_string()))?;
107            sqlx::query("DELETE FROM vector_collections WHERE name = ?")
108                .bind(&collection)
109                .execute(&self.pool)
110                .await
111                .map_err(|e| VectorStoreError::Delete(e.to_string()))?;
112            Ok(())
113        })
114    }
115
116    fn upsert(
117        &self,
118        collection: &str,
119        points: Vec<VectorPoint>,
120    ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
121        let collection = collection.to_owned();
122        Box::pin(async move {
123            for point in points {
124                let vector_bytes: Vec<u8> = bytemuck::cast_slice(&point.vector).to_vec();
125                let payload_json = serde_json::to_string(&point.payload)
126                    .map_err(|e| VectorStoreError::Serialization(e.to_string()))?;
127                sqlx::query(
128                    "INSERT INTO vector_points (id, collection, vector, payload) VALUES (?, ?, ?, ?) \
129                     ON CONFLICT(collection, id) DO UPDATE SET vector = excluded.vector, payload = excluded.payload",
130                )
131                .bind(&point.id)
132                .bind(&collection)
133                .bind(&vector_bytes)
134                .bind(&payload_json)
135                .execute(&self.pool)
136                .await
137                .map_err(|e| VectorStoreError::Upsert(e.to_string()))?;
138            }
139            Ok(())
140        })
141    }
142
143    fn search(
144        &self,
145        collection: &str,
146        vector: Vec<f32>,
147        limit: u64,
148        filter: Option<VectorFilter>,
149    ) -> BoxFuture<'_, Result<Vec<ScoredVectorPoint>, VectorStoreError>> {
150        let collection = collection.to_owned();
151        Box::pin(async move {
152            let rows: Vec<(String, Vec<u8>, String)> = sqlx::query_as(
153                "SELECT id, vector, payload FROM vector_points WHERE collection = ?",
154            )
155            .bind(&collection)
156            .fetch_all(&self.pool)
157            .await
158            .map_err(|e| VectorStoreError::Search(e.to_string()))?;
159
160            let limit_usize = usize::try_from(limit).unwrap_or(usize::MAX);
161            let mut scored: Vec<ScoredVectorPoint> = rows
162                .into_iter()
163                .filter_map(|(id, blob, payload_str)| {
164                    let Ok(stored) = bytemuck::try_cast_slice::<u8, f32>(&blob) else {
165                        return None;
166                    };
167                    let payload: HashMap<String, serde_json::Value> =
168                        serde_json::from_str(&payload_str).unwrap_or_default();
169
170                    if filter
171                        .as_ref()
172                        .is_some_and(|f| !matches_filter(&payload, f))
173                    {
174                        return None;
175                    }
176
177                    let score = cosine_similarity(&vector, stored);
178                    Some(ScoredVectorPoint { id, score, payload })
179                })
180                .collect();
181
182            scored.sort_by(|a, b| {
183                b.score
184                    .partial_cmp(&a.score)
185                    .unwrap_or(std::cmp::Ordering::Equal)
186            });
187            scored.truncate(limit_usize);
188            Ok(scored)
189        })
190    }
191
192    fn delete_by_ids(
193        &self,
194        collection: &str,
195        ids: Vec<String>,
196    ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
197        let collection = collection.to_owned();
198        Box::pin(async move {
199            for id in ids {
200                sqlx::query("DELETE FROM vector_points WHERE collection = ? AND id = ?")
201                    .bind(&collection)
202                    .bind(&id)
203                    .execute(&self.pool)
204                    .await
205                    .map_err(|e| VectorStoreError::Delete(e.to_string()))?;
206            }
207            Ok(())
208        })
209    }
210
211    fn scroll_all(
212        &self,
213        collection: &str,
214        key_field: &str,
215    ) -> BoxFuture<'_, Result<ScrollResult, VectorStoreError>> {
216        let collection = collection.to_owned();
217        let key_field = key_field.to_owned();
218        Box::pin(async move {
219            let rows: Vec<(String, String)> =
220                sqlx::query_as("SELECT id, payload FROM vector_points WHERE collection = ?")
221                    .bind(&collection)
222                    .fetch_all(&self.pool)
223                    .await
224                    .map_err(|e| VectorStoreError::Scroll(e.to_string()))?;
225
226            let mut result = ScrollResult::new();
227            for (id, payload_str) in rows {
228                let payload: HashMap<String, serde_json::Value> =
229                    serde_json::from_str(&payload_str).unwrap_or_default();
230                if let Some(val) = payload.get(&key_field) {
231                    let mut map = HashMap::new();
232                    map.insert(
233                        key_field.clone(),
234                        val.as_str().unwrap_or_default().to_owned(),
235                    );
236                    result.insert(id, map);
237                }
238            }
239            Ok(result)
240        })
241    }
242
243    fn health_check(&self) -> BoxFuture<'_, Result<bool, VectorStoreError>> {
244        Box::pin(async move {
245            sqlx::query_scalar::<_, i32>("SELECT 1")
246                .fetch_one(&self.pool)
247                .await
248                .map(|_| true)
249                .map_err(|e| VectorStoreError::Collection(e.to_string()))
250        })
251    }
252}
253
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sqlite::SqliteStore;
    use crate::vector_store::FieldCondition;

    /// Payload type used by every fixture point.
    type Payload = HashMap<String, serde_json::Value>;

    /// Unit vector along the first axis of the 4-dimensional test space.
    const E1: [f32; 4] = [1.0, 0.0, 0.0, 0.0];
    /// Unit vector along the second axis; orthogonal to [`E1`].
    const E2: [f32; 4] = [0.0, 1.0, 0.0, 0.0];

    /// Opens an in-memory store and a vector store sharing its pool.
    async fn setup() -> (SqliteVectorStore, SqliteStore) {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let vs = SqliteVectorStore::new(store.pool().clone());
        (vs, store)
    }

    /// Builds a payload from `(key, value)` pairs.
    fn payload<const N: usize>(fields: [(&'static str, serde_json::Value); N]) -> Payload {
        HashMap::from(fields.map(|(key, value)| (key.to_owned(), value)))
    }

    /// Builds a fixture point.
    fn point(id: &'static str, vector: [f32; 4], payload: Payload) -> VectorPoint {
        VectorPoint {
            id: id.into(),
            vector: vector.to_vec(),
            payload,
        }
    }

    /// Condition "`field` equals the text `value`".
    fn text_eq(field: &'static str, value: &'static str) -> FieldCondition {
        FieldCondition {
            field: field.into(),
            value: FieldValue::Text(value.into()),
        }
    }

    /// Condition "`field` equals the integer `value`".
    fn int_eq(field: &'static str, value: i64) -> FieldCondition {
        FieldCondition {
            field: field.into(),
            value: FieldValue::Integer(value),
        }
    }

    /// Creates collection `"c"` and stores `points` in it.
    async fn seed(vs: &SqliteVectorStore, points: Vec<VectorPoint>) {
        vs.ensure_collection("c", 4).await.unwrap();
        vs.upsert("c", points).await.unwrap();
    }

    /// Searches collection `"c"` and unwraps the result.
    async fn search_c(
        vs: &SqliteVectorStore,
        query: [f32; 4],
        limit: u64,
        filter: Option<VectorFilter>,
    ) -> Vec<ScoredVectorPoint> {
        vs.search("c", query.to_vec(), limit, filter).await.unwrap()
    }

    #[tokio::test]
    async fn ensure_and_exists() {
        let (vs, _) = setup().await;
        assert!(!vs.collection_exists("col1").await.unwrap());
        vs.ensure_collection("col1", 4).await.unwrap();
        assert!(vs.collection_exists("col1").await.unwrap());
        // A second call must be a harmless no-op.
        vs.ensure_collection("col1", 4).await.unwrap();
        assert!(vs.collection_exists("col1").await.unwrap());
    }

    #[tokio::test]
    async fn delete_collection() {
        let (vs, _) = setup().await;
        vs.ensure_collection("col1", 4).await.unwrap();
        vs.upsert("col1", vec![point("p1", E1, Payload::new())])
            .await
            .unwrap();
        vs.delete_collection("col1").await.unwrap();
        assert!(!vs.collection_exists("col1").await.unwrap());
    }

    #[tokio::test]
    async fn upsert_and_search() {
        let (vs, _) = setup().await;
        seed(
            &vs,
            vec![
                point("a", E1, payload([("role", serde_json::json!("user"))])),
                point("b", E2, payload([("role", serde_json::json!("assistant"))])),
            ],
        )
        .await;

        let results = search_c(&vs, E1, 2, None).await;
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "a");
        assert!((results[0].score - 1.0).abs() < 1e-5);
    }

    #[tokio::test]
    async fn search_with_filter() {
        let (vs, _) = setup().await;
        seed(
            &vs,
            vec![
                point("a", E1, payload([("role", serde_json::json!("user"))])),
                point("b", E1, payload([("role", serde_json::json!("assistant"))])),
            ],
        )
        .await;

        let filter = VectorFilter {
            must: vec![text_eq("role", "user")],
            must_not: vec![],
        };
        let results = search_c(&vs, E1, 10, Some(filter)).await;
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "a");
    }

    #[tokio::test]
    async fn delete_by_ids() {
        let (vs, _) = setup().await;
        seed(
            &vs,
            vec![
                point("a", E1, Payload::new()),
                point("b", E2, Payload::new()),
            ],
        )
        .await;
        vs.delete_by_ids("c", vec!["a".into()]).await.unwrap();

        let results = search_c(&vs, E1, 10, None).await;
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "b");
    }

    #[tokio::test]
    async fn scroll_all() {
        let (vs, _) = setup().await;
        seed(
            &vs,
            vec![point(
                "p1",
                E1,
                payload([("text", serde_json::json!("hello"))]),
            )],
        )
        .await;
        let result = vs.scroll_all("c", "text").await.unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result["p1"]["text"], "hello");
    }

    #[tokio::test]
    async fn upsert_updates_existing() {
        let (vs, _) = setup().await;
        seed(
            &vs,
            vec![point("p1", E1, payload([("v", serde_json::json!(1))]))],
        )
        .await;
        // Same id again: the stored vector must be replaced, not duplicated.
        vs.upsert(
            "c",
            vec![point("p1", E2, payload([("v", serde_json::json!(2))]))],
        )
        .await
        .unwrap();

        let results = search_c(&vs, E2, 1, None).await;
        assert_eq!(results.len(), 1);
        assert!((results[0].score - 1.0).abs() < 1e-5);
    }

    #[test]
    fn cosine_similarity_orthogonal() {
        assert_eq!(cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]), 0.0);
    }

    #[test]
    fn cosine_similarity_identical() {
        let v = vec![3.0, 4.0];
        assert!((cosine_similarity(&v, &v) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_zero_vector() {
        assert_eq!(cosine_similarity(&[0.0, 0.0], &[1.0, 0.0]), 0.0);
    }

    #[test]
    fn cosine_similarity_length_mismatch() {
        assert_eq!(cosine_similarity(&[1.0], &[1.0, 0.0]), 0.0);
    }

    #[tokio::test]
    async fn search_with_must_not_filter() {
        let (vs, _) = setup().await;
        seed(
            &vs,
            vec![
                point("a", E1, payload([("role", serde_json::json!("user"))])),
                point("b", E1, payload([("role", serde_json::json!("system"))])),
            ],
        )
        .await;

        let filter = VectorFilter {
            must: vec![],
            must_not: vec![text_eq("role", "system")],
        };
        let results = search_c(&vs, E1, 10, Some(filter)).await;
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "a");
    }

    #[tokio::test]
    async fn search_with_integer_filter() {
        let (vs, _) = setup().await;
        seed(
            &vs,
            vec![
                point("a", E1, payload([("conv_id", serde_json::json!(1))])),
                point("b", E1, payload([("conv_id", serde_json::json!(2))])),
            ],
        )
        .await;

        let filter = VectorFilter {
            must: vec![int_eq("conv_id", 1)],
            must_not: vec![],
        };
        let results = search_c(&vs, E1, 10, Some(filter)).await;
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "a");
    }

    #[tokio::test]
    async fn search_empty_collection() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        let results = search_c(&vs, E1, 10, None).await;
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn search_with_must_not_integer_filter() {
        let (vs, _) = setup().await;
        seed(
            &vs,
            vec![
                point("a", E1, payload([("conv_id", serde_json::json!(1))])),
                point("b", E1, payload([("conv_id", serde_json::json!(2))])),
            ],
        )
        .await;

        let filter = VectorFilter {
            must: vec![],
            must_not: vec![int_eq("conv_id", 1)],
        };
        let results = search_c(&vs, E1, 10, Some(filter)).await;
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "b");
    }

    #[tokio::test]
    async fn search_with_combined_must_and_must_not() {
        let (vs, _) = setup().await;
        seed(
            &vs,
            vec![
                point(
                    "a",
                    E1,
                    payload([
                        ("role", serde_json::json!("user")),
                        ("conv_id", serde_json::json!(1)),
                    ]),
                ),
                point(
                    "b",
                    E1,
                    payload([
                        ("role", serde_json::json!("user")),
                        ("conv_id", serde_json::json!(2)),
                    ]),
                ),
                point(
                    "c",
                    E1,
                    payload([
                        ("role", serde_json::json!("assistant")),
                        ("conv_id", serde_json::json!(1)),
                    ]),
                ),
            ],
        )
        .await;

        let filter = VectorFilter {
            must: vec![text_eq("role", "user")],
            must_not: vec![int_eq("conv_id", 2)],
        };
        let results = search_c(&vs, E1, 10, Some(filter)).await;
        // Only "a" has role=user AND conv_id != 2.
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "a");
    }

    #[tokio::test]
    async fn scroll_all_missing_key_field() {
        let (vs, _) = setup().await;
        seed(
            &vs,
            vec![point(
                "p1",
                E1,
                payload([("other", serde_json::json!("value"))]),
            )],
        )
        .await;
        // The payload has no "text" entry, so the point must be left out.
        let result = vs.scroll_all("c", "text").await.unwrap();
        assert!(
            result.is_empty(),
            "points without the key field must not appear in scroll result"
        );
    }

    #[tokio::test]
    async fn delete_by_ids_empty_and_nonexistent() {
        let (vs, _) = setup().await;
        seed(&vs, vec![point("a", E1, Payload::new())]).await;

        // An empty id list is a no-op and must succeed.
        vs.delete_by_ids("c", vec![]).await.unwrap();

        // An id that matches nothing must succeed as well.
        vs.delete_by_ids("c", vec!["nonexistent".into()])
            .await
            .unwrap();

        // The original point is still there.
        let results = search_c(&vs, E1, 10, None).await;
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "a");
    }

    #[tokio::test]
    async fn search_corrupt_blob_skipped() {
        let (vs, store) = setup().await;
        seed(&vs, vec![point("valid", E1, Payload::new())]).await;

        // Write a row directly whose blob is 3 bytes long: not a multiple of
        // the 4 bytes an f32 needs, so it cannot be decoded as a vector.
        let corrupt_blob: Vec<u8> = vec![0xFF, 0xFE, 0xFD];
        let payload_json = r#"{}"#;
        sqlx::query(
            "INSERT INTO vector_points (id, collection, vector, payload) VALUES (?, ?, ?, ?)",
        )
        .bind("corrupt")
        .bind("c")
        .bind(&corrupt_blob)
        .bind(payload_json)
        .execute(store.pool())
        .await
        .unwrap();

        // The search must neither panic nor return the corrupt point.
        let results = search_c(&vs, E1, 10, None).await;
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "valid");
    }
}