//! SQLite-backed implementation of the `VectorStore` trait
//! (`zeph_memory/sqlite_vector_store.rs`).

1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4
5use sqlx::SqlitePool;
6
7use crate::vector_store::{
8    FieldValue, ScoredVectorPoint, ScrollResult, VectorFilter, VectorPoint, VectorStore,
9    VectorStoreError,
10};
11
12pub struct SqliteVectorStore {
13    pool: SqlitePool,
14}
15
16impl SqliteVectorStore {
17    #[must_use]
18    pub fn new(pool: SqlitePool) -> Self {
19        Self { pool }
20    }
21}
22
/// Cosine similarity of `a` and `b`, clamped to `[-1.0, 1.0]`.
///
/// Returns `0.0` for degenerate input: vectors of different lengths, empty vectors, a
/// zero-norm vector, or a non-finite result (e.g. a `NaN` or infinite component).
///
/// Sums are accumulated in `f64`. Squaring a large `f32` component (|x| > ~1.8e19) would
/// overflow `f32` to infinity and turn the result into `inf / inf = NaN`. Wider
/// accumulation also loses less precision on long vectors.
// The final `f64 -> f32` narrowing is intentional: the value is clamped to [-1, 1] first.
#[allow(clippy::cast_possible_truncation)]
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }
    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, na, nb), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, na + x * x, nb + y * y)
        },
    );
    if norm_a_sq == 0.0 || norm_b_sq == 0.0 {
        return 0.0;
    }
    let similarity = dot / (norm_a_sq.sqrt() * norm_b_sq.sqrt());
    if !similarity.is_finite() {
        // NaN/inf components carry no usable similarity; report the neutral score so
        // callers can always rank results.
        return 0.0;
    }
    // Rounding can push |similarity| a hair past 1.0 for (anti)parallel vectors.
    similarity.clamp(-1.0, 1.0) as f32
}
35
36fn matches_filter(payload: &HashMap<String, serde_json::Value>, filter: &VectorFilter) -> bool {
37    for cond in &filter.must {
38        let Some(val) = payload.get(&cond.field) else {
39            return false;
40        };
41        let matches = match &cond.value {
42            FieldValue::Integer(i) => val.as_i64().is_some_and(|v| v == *i),
43            FieldValue::Text(t) => val.as_str().is_some_and(|v| v == t.as_str()),
44        };
45        if !matches {
46            return false;
47        }
48    }
49    for cond in &filter.must_not {
50        let Some(val) = payload.get(&cond.field) else {
51            continue;
52        };
53        let matches = match &cond.value {
54            FieldValue::Integer(i) => val.as_i64().is_some_and(|v| v == *i),
55            FieldValue::Text(t) => val.as_str().is_some_and(|v| v == t.as_str()),
56        };
57        if matches {
58            return false;
59        }
60    }
61    true
62}
63
/// Pinned, heap-allocated, `Send` future that may borrow for `'a`.
///
/// This is the return type of every `VectorStore` method implemented for
/// `SqliteVectorStore` in this file. Each method wraps an `async move` block with
/// `Box::pin`.
type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
65
66impl VectorStore for SqliteVectorStore {
67    fn ensure_collection(
68        &self,
69        collection: &str,
70        _vector_size: u64,
71    ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
72        let collection = collection.to_owned();
73        Box::pin(async move {
74            sqlx::query("INSERT OR IGNORE INTO vector_collections (name) VALUES (?)")
75                .bind(&collection)
76                .execute(&self.pool)
77                .await
78                .map_err(|e| VectorStoreError::Collection(e.to_string()))?;
79            Ok(())
80        })
81    }
82
83    fn collection_exists(&self, collection: &str) -> BoxFuture<'_, Result<bool, VectorStoreError>> {
84        let collection = collection.to_owned();
85        Box::pin(async move {
86            let row: (i64,) =
87                sqlx::query_as("SELECT COUNT(*) FROM vector_collections WHERE name = ?")
88                    .bind(&collection)
89                    .fetch_one(&self.pool)
90                    .await
91                    .map_err(|e| VectorStoreError::Connection(e.to_string()))?;
92            Ok(row.0 > 0)
93        })
94    }
95
96    fn delete_collection(&self, collection: &str) -> BoxFuture<'_, Result<(), VectorStoreError>> {
97        let collection = collection.to_owned();
98        Box::pin(async move {
99            sqlx::query("DELETE FROM vector_points WHERE collection = ?")
100                .bind(&collection)
101                .execute(&self.pool)
102                .await
103                .map_err(|e| VectorStoreError::Delete(e.to_string()))?;
104            sqlx::query("DELETE FROM vector_collections WHERE name = ?")
105                .bind(&collection)
106                .execute(&self.pool)
107                .await
108                .map_err(|e| VectorStoreError::Delete(e.to_string()))?;
109            Ok(())
110        })
111    }
112
113    fn upsert(
114        &self,
115        collection: &str,
116        points: Vec<VectorPoint>,
117    ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
118        let collection = collection.to_owned();
119        Box::pin(async move {
120            for point in points {
121                let vector_bytes: Vec<u8> = bytemuck::cast_slice(&point.vector).to_vec();
122                let payload_json = serde_json::to_string(&point.payload)
123                    .map_err(|e| VectorStoreError::Serialization(e.to_string()))?;
124                sqlx::query(
125                    "INSERT INTO vector_points (id, collection, vector, payload) VALUES (?, ?, ?, ?) \
126                     ON CONFLICT(collection, id) DO UPDATE SET vector = excluded.vector, payload = excluded.payload",
127                )
128                .bind(&point.id)
129                .bind(&collection)
130                .bind(&vector_bytes)
131                .bind(&payload_json)
132                .execute(&self.pool)
133                .await
134                .map_err(|e| VectorStoreError::Upsert(e.to_string()))?;
135            }
136            Ok(())
137        })
138    }
139
140    fn search(
141        &self,
142        collection: &str,
143        vector: Vec<f32>,
144        limit: u64,
145        filter: Option<VectorFilter>,
146    ) -> BoxFuture<'_, Result<Vec<ScoredVectorPoint>, VectorStoreError>> {
147        let collection = collection.to_owned();
148        Box::pin(async move {
149            let rows: Vec<(String, Vec<u8>, String)> = sqlx::query_as(
150                "SELECT id, vector, payload FROM vector_points WHERE collection = ?",
151            )
152            .bind(&collection)
153            .fetch_all(&self.pool)
154            .await
155            .map_err(|e| VectorStoreError::Search(e.to_string()))?;
156
157            let limit_usize = usize::try_from(limit).unwrap_or(usize::MAX);
158            let mut scored: Vec<ScoredVectorPoint> = rows
159                .into_iter()
160                .filter_map(|(id, blob, payload_str)| {
161                    let Ok(stored) = bytemuck::try_cast_slice::<u8, f32>(&blob) else {
162                        return None;
163                    };
164                    let payload: HashMap<String, serde_json::Value> =
165                        serde_json::from_str(&payload_str).unwrap_or_default();
166
167                    if filter
168                        .as_ref()
169                        .is_some_and(|f| !matches_filter(&payload, f))
170                    {
171                        return None;
172                    }
173
174                    let score = cosine_similarity(&vector, stored);
175                    Some(ScoredVectorPoint { id, score, payload })
176                })
177                .collect();
178
179            scored.sort_by(|a, b| {
180                b.score
181                    .partial_cmp(&a.score)
182                    .unwrap_or(std::cmp::Ordering::Equal)
183            });
184            scored.truncate(limit_usize);
185            Ok(scored)
186        })
187    }
188
189    fn delete_by_ids(
190        &self,
191        collection: &str,
192        ids: Vec<String>,
193    ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
194        let collection = collection.to_owned();
195        Box::pin(async move {
196            for id in ids {
197                sqlx::query("DELETE FROM vector_points WHERE collection = ? AND id = ?")
198                    .bind(&collection)
199                    .bind(&id)
200                    .execute(&self.pool)
201                    .await
202                    .map_err(|e| VectorStoreError::Delete(e.to_string()))?;
203            }
204            Ok(())
205        })
206    }
207
208    fn scroll_all(
209        &self,
210        collection: &str,
211        key_field: &str,
212    ) -> BoxFuture<'_, Result<ScrollResult, VectorStoreError>> {
213        let collection = collection.to_owned();
214        let key_field = key_field.to_owned();
215        Box::pin(async move {
216            let rows: Vec<(String, String)> =
217                sqlx::query_as("SELECT id, payload FROM vector_points WHERE collection = ?")
218                    .bind(&collection)
219                    .fetch_all(&self.pool)
220                    .await
221                    .map_err(|e| VectorStoreError::Scroll(e.to_string()))?;
222
223            let mut result = ScrollResult::new();
224            for (id, payload_str) in rows {
225                let payload: HashMap<String, serde_json::Value> =
226                    serde_json::from_str(&payload_str).unwrap_or_default();
227                if let Some(val) = payload.get(&key_field) {
228                    let mut map = HashMap::new();
229                    map.insert(
230                        key_field.clone(),
231                        val.as_str().unwrap_or_default().to_owned(),
232                    );
233                    result.insert(id, map);
234                }
235            }
236            Ok(result)
237        })
238    }
239
240    fn health_check(&self) -> BoxFuture<'_, Result<bool, VectorStoreError>> {
241        Box::pin(async move {
242            sqlx::query_scalar::<_, i32>("SELECT 1")
243                .fetch_one(&self.pool)
244                .await
245                .map(|_| true)
246                .map_err(|e| VectorStoreError::Collection(e.to_string()))
247        })
248    }
249}
250
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sqlite::SqliteStore;
    use crate::vector_store::FieldCondition;

    /// Unit vector along the first axis.
    const X_AXIS: [f32; 4] = [1.0, 0.0, 0.0, 0.0];
    /// Unit vector along the second axis (orthogonal to `X_AXIS`).
    const Y_AXIS: [f32; 4] = [0.0, 1.0, 0.0, 0.0];

    /// Opens a fresh in-memory `SqliteStore` and builds a vector store on its pool.
    ///
    /// The `SqliteStore` is returned too. Tests bind it (`_store`) so the pool's owner
    /// lives for the whole test, and can run raw SQL through `store.pool()`.
    async fn setup() -> (SqliteVectorStore, SqliteStore) {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let vs = SqliteVectorStore::new(store.pool().clone());
        (vs, store)
    }

    /// Like `setup`, but also creates collection `"c"` (size 4) and upserts `points`.
    async fn seeded(points: Vec<VectorPoint>) -> (SqliteVectorStore, SqliteStore) {
        let (vs, store) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        vs.upsert("c", points).await.unwrap();
        (vs, store)
    }

    /// Builds a 4-dimensional point whose payload holds the given `(key, value)` pairs.
    fn point(id: &str, vector: [f32; 4], payload: &[(&str, serde_json::Value)]) -> VectorPoint {
        VectorPoint {
            id: id.into(),
            vector: vector.to_vec(),
            payload: payload
                .iter()
                .map(|(key, value)| ((*key).into(), value.clone()))
                .collect(),
        }
    }

    /// Shorthand for a filter condition on `field`.
    fn cond(field: &str, value: FieldValue) -> FieldCondition {
        FieldCondition {
            field: field.into(),
            value,
        }
    }

    #[tokio::test]
    async fn ensure_and_exists() {
        let (vs, _store) = setup().await;
        assert!(!vs.collection_exists("col1").await.unwrap());
        vs.ensure_collection("col1", 4).await.unwrap();
        assert!(vs.collection_exists("col1").await.unwrap());
        // idempotent
        vs.ensure_collection("col1", 4).await.unwrap();
        assert!(vs.collection_exists("col1").await.unwrap());
    }

    #[tokio::test]
    async fn delete_collection() {
        let (vs, _store) = seeded(vec![point("p1", X_AXIS, &[])]).await;
        vs.delete_collection("c").await.unwrap();
        assert!(!vs.collection_exists("c").await.unwrap());

        // The points must be gone as well, not just the collection record.
        vs.ensure_collection("c", 4).await.unwrap();
        let results = vs.search("c", X_AXIS.to_vec(), 10, None).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn delete_missing_collection_is_ok() {
        let (vs, _store) = setup().await;
        vs.delete_collection("never-created").await.unwrap();
        assert!(!vs.collection_exists("never-created").await.unwrap());
    }

    #[tokio::test]
    async fn upsert_and_search() {
        let (vs, _store) = seeded(vec![
            point("a", X_AXIS, &[("role", serde_json::json!("user"))]),
            point("b", Y_AXIS, &[("role", serde_json::json!("assistant"))]),
        ])
        .await;

        let results = vs.search("c", X_AXIS.to_vec(), 2, None).await.unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "a");
        assert!((results[0].score - 1.0).abs() < 1e-5);
        assert!(results[1].score.abs() < 1e-5);
        assert_eq!(results[0].payload["role"], serde_json::json!("user"));
    }

    #[tokio::test]
    async fn search_respects_limit() {
        let (vs, _store) = seeded(vec![
            point("a", X_AXIS, &[]),
            point("b", [1.0, 1.0, 0.0, 0.0], &[]),
            point("z", Y_AXIS, &[]),
        ])
        .await;

        let results = vs.search("c", X_AXIS.to_vec(), 2, None).await.unwrap();
        let ids: Vec<&str> = results.iter().map(|r| r.id.as_str()).collect();
        assert_eq!(ids, ["a", "b"]);

        let none = vs.search("c", X_AXIS.to_vec(), 0, None).await.unwrap();
        assert!(none.is_empty());
    }

    #[tokio::test]
    async fn search_with_filter() {
        let (vs, _store) = seeded(vec![
            point("a", X_AXIS, &[("role", serde_json::json!("user"))]),
            point("b", X_AXIS, &[("role", serde_json::json!("assistant"))]),
        ])
        .await;

        let filter = VectorFilter {
            must: vec![cond("role", FieldValue::Text("user".into()))],
            must_not: vec![],
        };
        let results = vs
            .search("c", X_AXIS.to_vec(), 10, Some(filter))
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "a");
    }

    #[tokio::test]
    async fn filter_on_missing_field() {
        let (vs, _store) = seeded(vec![point("a", X_AXIS, &[])]).await;

        // `must` on an absent field excludes the point.
        let must = VectorFilter {
            must: vec![cond("role", FieldValue::Text("user".into()))],
            must_not: vec![],
        };
        let results = vs.search("c", X_AXIS.to_vec(), 10, Some(must)).await.unwrap();
        assert!(results.is_empty());

        // `must_not` on an absent field keeps the point.
        let must_not = VectorFilter {
            must: vec![],
            must_not: vec![cond("role", FieldValue::Text("user".into()))],
        };
        let results = vs
            .search("c", X_AXIS.to_vec(), 10, Some(must_not))
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
    }

    #[tokio::test]
    async fn delete_by_ids() {
        let (vs, _store) = seeded(vec![point("a", X_AXIS, &[]), point("b", Y_AXIS, &[])]).await;
        vs.delete_by_ids("c", vec!["a".into()]).await.unwrap();
        let results = vs.search("c", X_AXIS.to_vec(), 10, None).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "b");
    }

    #[tokio::test]
    async fn scroll_all() {
        let (vs, _store) =
            seeded(vec![point("p1", X_AXIS, &[("text", serde_json::json!("hello"))])]).await;
        let result = vs.scroll_all("c", "text").await.unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result["p1"]["text"], "hello");
    }

    #[tokio::test]
    async fn upsert_updates_existing() {
        let (vs, _store) =
            seeded(vec![point("p1", X_AXIS, &[("v", serde_json::json!(1))])]).await;
        vs.upsert("c", vec![point("p1", Y_AXIS, &[("v", serde_json::json!(2))])])
            .await
            .unwrap();

        let results = vs.search("c", Y_AXIS.to_vec(), 10, None).await.unwrap();
        // A generous limit: a duplicate row would show up as a second result.
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "p1");
        assert!((results[0].score - 1.0).abs() < 1e-5);
        assert_eq!(results[0].payload["v"], serde_json::json!(2));
    }

    #[test]
    fn cosine_similarity_orthogonal() {
        assert_eq!(cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]), 0.0);
    }

    #[test]
    fn cosine_similarity_identical() {
        let v = [3.0_f32, 4.0];
        assert!((cosine_similarity(&v, &v) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_zero_vector() {
        assert_eq!(cosine_similarity(&[0.0, 0.0], &[1.0, 0.0]), 0.0);
    }

    #[test]
    fn cosine_similarity_length_mismatch() {
        assert_eq!(cosine_similarity(&[1.0], &[1.0, 0.0]), 0.0);
    }

    #[tokio::test]
    async fn search_with_must_not_filter() {
        let (vs, _store) = seeded(vec![
            point("a", X_AXIS, &[("role", serde_json::json!("user"))]),
            point("b", X_AXIS, &[("role", serde_json::json!("system"))]),
        ])
        .await;

        let filter = VectorFilter {
            must: vec![],
            must_not: vec![cond("role", FieldValue::Text("system".into()))],
        };
        let results = vs
            .search("c", X_AXIS.to_vec(), 10, Some(filter))
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "a");
    }

    #[tokio::test]
    async fn search_with_integer_filter() {
        let (vs, _store) = seeded(vec![
            point("a", X_AXIS, &[("conv_id", serde_json::json!(1))]),
            point("b", X_AXIS, &[("conv_id", serde_json::json!(2))]),
        ])
        .await;

        let filter = VectorFilter {
            must: vec![cond("conv_id", FieldValue::Integer(1))],
            must_not: vec![],
        };
        let results = vs
            .search("c", X_AXIS.to_vec(), 10, Some(filter))
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "a");
    }

    #[tokio::test]
    async fn search_empty_collection() {
        // Also exercises upserting an empty batch.
        let (vs, _store) = seeded(vec![]).await;
        let results = vs.search("c", X_AXIS.to_vec(), 10, None).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn search_with_must_not_integer_filter() {
        let (vs, _store) = seeded(vec![
            point("a", X_AXIS, &[("conv_id", serde_json::json!(1))]),
            point("b", X_AXIS, &[("conv_id", serde_json::json!(2))]),
        ])
        .await;

        let filter = VectorFilter {
            must: vec![],
            must_not: vec![cond("conv_id", FieldValue::Integer(1))],
        };
        let results = vs
            .search("c", X_AXIS.to_vec(), 10, Some(filter))
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "b");
    }

    #[tokio::test]
    async fn search_with_combined_must_and_must_not() {
        let (vs, _store) = seeded(vec![
            point(
                "a",
                X_AXIS,
                &[
                    ("role", serde_json::json!("user")),
                    ("conv_id", serde_json::json!(1)),
                ],
            ),
            point(
                "b",
                X_AXIS,
                &[
                    ("role", serde_json::json!("user")),
                    ("conv_id", serde_json::json!(2)),
                ],
            ),
            point(
                "c",
                X_AXIS,
                &[
                    ("role", serde_json::json!("assistant")),
                    ("conv_id", serde_json::json!(1)),
                ],
            ),
        ])
        .await;

        let filter = VectorFilter {
            must: vec![cond("role", FieldValue::Text("user".into()))],
            must_not: vec![cond("conv_id", FieldValue::Integer(2))],
        };
        let results = vs
            .search("c", X_AXIS.to_vec(), 10, Some(filter))
            .await
            .unwrap();
        // Only "a": role=user AND conv_id != 2
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "a");
    }

    #[tokio::test]
    async fn scroll_all_missing_key_field() {
        let (vs, _store) =
            seeded(vec![point("p1", X_AXIS, &[("other", serde_json::json!("value"))])]).await;
        // key_field "text" doesn't exist in payload → point excluded from result
        let result = vs.scroll_all("c", "text").await.unwrap();
        assert!(
            result.is_empty(),
            "points without the key field must not appear in scroll result"
        );
    }

    #[tokio::test]
    async fn delete_by_ids_empty_and_nonexistent() {
        let (vs, _store) = seeded(vec![point("a", X_AXIS, &[])]).await;

        // Empty list: no-op, must succeed
        vs.delete_by_ids("c", vec![]).await.unwrap();

        // Non-existent id: must succeed (idempotent)
        vs.delete_by_ids("c", vec!["nonexistent".into()])
            .await
            .unwrap();

        // Original point still present
        let results = vs.search("c", X_AXIS.to_vec(), 10, None).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "a");
    }

    #[tokio::test]
    async fn search_corrupt_blob_skipped() {
        let (vs, store) = seeded(vec![point("valid", X_AXIS, &[])]).await;

        // Insert raw invalid bytes directly into the vector_points table:
        // 3 bytes cannot be decoded as f32 values (each needs 4 bytes).
        let corrupt_blob: Vec<u8> = vec![0xFF, 0xFE, 0xFD];
        sqlx::query(
            "INSERT INTO vector_points (id, collection, vector, payload) VALUES (?, ?, ?, ?)",
        )
        .bind("corrupt")
        .bind("c")
        .bind(&corrupt_blob)
        .bind("{}")
        .execute(store.pool())
        .await
        .unwrap();

        // Search must not panic and must skip the corrupt point
        let results = vs.search("c", X_AXIS.to_vec(), 10, None).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "valid");
    }
}