Skip to main content

zeph_memory/
in_memory_store.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
//! Purely in-memory [`VectorStore`] implementation for unit tests.
//!
//! All data lives in a `RwLock<HashMap>` and is discarded when the store is dropped.
//! Cosine similarity is computed with [`zeph_common::math::cosine_similarity`].
//! Not suitable for production — use [`crate::qdrant_ops::QdrantOps`] instead.
9
10use std::collections::HashMap;
11
12use parking_lot::RwLock;
13
14use crate::vector_store::{
15    BoxFuture, FieldValue, ScoredVectorPoint, ScrollWithIdsResult, VectorFilter, VectorPoint,
16    VectorStore, VectorStoreError,
17};
18
19struct StoredPoint {
20    vector: Vec<f32>,
21    payload: HashMap<String, serde_json::Value>,
22}
23
24struct InMemoryCollection {
25    points: HashMap<String, StoredPoint>,
26}
27
28/// In-process vector store backed by a `RwLock<HashMap>`.
29///
30/// Intended for unit tests only.  All data is lost when the store is dropped.
31///
32/// # Examples
33///
34/// ```
35/// use zeph_memory::in_memory_store::InMemoryVectorStore;
36///
37/// let store = InMemoryVectorStore::new();
38/// // Use `store` as a `VectorStore` in tests.
39/// ```
40pub struct InMemoryVectorStore {
41    collections: RwLock<HashMap<String, InMemoryCollection>>,
42}
43
44impl InMemoryVectorStore {
45    #[must_use]
46    pub fn new() -> Self {
47        Self {
48            collections: RwLock::new(HashMap::new()),
49        }
50    }
51}
52
53impl Default for InMemoryVectorStore {
54    fn default() -> Self {
55        Self::new()
56    }
57}
58
59impl std::fmt::Debug for InMemoryVectorStore {
60    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61        f.debug_struct("InMemoryVectorStore")
62            .finish_non_exhaustive()
63    }
64}
65
66use zeph_common::math::cosine_similarity;
67
68fn matches_filter(payload: &HashMap<String, serde_json::Value>, filter: &VectorFilter) -> bool {
69    for cond in &filter.must {
70        let Some(val) = payload.get(&cond.field) else {
71            return false;
72        };
73        if !field_matches(val, &cond.value) {
74            return false;
75        }
76    }
77    for cond in &filter.must_not {
78        if let Some(val) = payload.get(&cond.field)
79            && field_matches(val, &cond.value)
80        {
81            return false;
82        }
83    }
84    true
85}
86
87fn field_matches(val: &serde_json::Value, expected: &FieldValue) -> bool {
88    match expected {
89        FieldValue::Integer(i) => val.as_i64() == Some(*i),
90        FieldValue::Text(s) => val.as_str() == Some(s.as_str()),
91    }
92}
93
94impl VectorStore for InMemoryVectorStore {
95    fn ensure_collection(
96        &self,
97        collection: &str,
98        _vector_size: u64,
99    ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
100        let collection = collection.to_owned();
101        Box::pin(async move {
102            let mut cols = self.collections.write();
103            cols.entry(collection)
104                .or_insert_with(|| InMemoryCollection {
105                    points: HashMap::new(),
106                });
107            Ok(())
108        })
109    }
110
111    fn collection_exists(&self, collection: &str) -> BoxFuture<'_, Result<bool, VectorStoreError>> {
112        let collection = collection.to_owned();
113        Box::pin(async move {
114            let cols = self.collections.read();
115            Ok(cols.contains_key(&collection))
116        })
117    }
118
119    fn delete_collection(&self, collection: &str) -> BoxFuture<'_, Result<(), VectorStoreError>> {
120        let collection = collection.to_owned();
121        Box::pin(async move {
122            let mut cols = self.collections.write();
123            cols.remove(&collection);
124            Ok(())
125        })
126    }
127
128    fn upsert(
129        &self,
130        collection: &str,
131        points: Vec<VectorPoint>,
132    ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
133        let collection = collection.to_owned();
134        Box::pin(async move {
135            let mut cols = self.collections.write();
136            let col = cols.get_mut(&collection).ok_or_else(|| {
137                VectorStoreError::Upsert(format!("collection {collection} not found"))
138            })?;
139            for p in points {
140                col.points.insert(
141                    p.id,
142                    StoredPoint {
143                        vector: p.vector,
144                        payload: p.payload,
145                    },
146                );
147            }
148            Ok(())
149        })
150    }
151
152    fn search(
153        &self,
154        collection: &str,
155        vector: Vec<f32>,
156        limit: u64,
157        filter: Option<VectorFilter>,
158    ) -> BoxFuture<'_, Result<Vec<ScoredVectorPoint>, VectorStoreError>> {
159        let collection = collection.to_owned();
160        Box::pin(async move {
161            let cols = self.collections.read();
162            let col = cols.get(&collection).ok_or_else(|| {
163                VectorStoreError::Search(format!("collection {collection} not found"))
164            })?;
165
166            let empty_filter = VectorFilter::default();
167            let f = filter.as_ref().unwrap_or(&empty_filter);
168
169            let mut scored: Vec<ScoredVectorPoint> = col
170                .points
171                .iter()
172                .filter(|(_, sp)| matches_filter(&sp.payload, f))
173                .map(|(id, sp)| ScoredVectorPoint {
174                    id: id.clone(),
175                    score: cosine_similarity(&vector, &sp.vector),
176                    payload: sp.payload.clone(),
177                })
178                .collect();
179
180            scored.sort_by(|a, b| {
181                b.score
182                    .partial_cmp(&a.score)
183                    .unwrap_or(std::cmp::Ordering::Equal)
184            });
185            #[expect(clippy::cast_possible_truncation)]
186            scored.truncate(limit as usize);
187            Ok(scored)
188        })
189    }
190
191    fn delete_by_ids(
192        &self,
193        collection: &str,
194        ids: Vec<String>,
195    ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
196        let collection = collection.to_owned();
197        Box::pin(async move {
198            if ids.is_empty() {
199                return Ok(());
200            }
201            let mut cols = self.collections.write();
202            let col = cols.get_mut(&collection).ok_or_else(|| {
203                VectorStoreError::Delete(format!("collection {collection} not found"))
204            })?;
205            for id in &ids {
206                col.points.remove(id);
207            }
208            Ok(())
209        })
210    }
211
212    fn scroll_all(
213        &self,
214        collection: &str,
215        key_field: &str,
216    ) -> BoxFuture<'_, Result<HashMap<String, HashMap<String, String>>, VectorStoreError>> {
217        let collection = collection.to_owned();
218        let key_field = key_field.to_owned();
219        Box::pin(async move {
220            let cols = self.collections.read();
221            let col = cols.get(&collection).ok_or_else(|| {
222                VectorStoreError::Scroll(format!("collection {collection} not found"))
223            })?;
224
225            let mut result = HashMap::new();
226            for sp in col.points.values() {
227                let Some(key_val) = sp.payload.get(&key_field).and_then(|v| v.as_str()) else {
228                    continue;
229                };
230                let mut fields = HashMap::new();
231                for (k, v) in &sp.payload {
232                    if let Some(s) = v.as_str() {
233                        fields.insert(k.clone(), s.to_owned());
234                    }
235                }
236                result.insert(key_val.to_owned(), fields);
237            }
238            Ok(result)
239        })
240    }
241
242    fn scroll_all_with_point_ids(
243        &self,
244        collection: &str,
245        key_field: &str,
246    ) -> BoxFuture<'_, Result<ScrollWithIdsResult, VectorStoreError>> {
247        let collection = collection.to_owned();
248        let key_field = key_field.to_owned();
249        Box::pin(async move {
250            let cols = self.collections.read();
251            let col = cols.get(&collection).ok_or_else(|| {
252                VectorStoreError::Scroll(format!("collection {collection} not found"))
253            })?;
254
255            let mut result = Vec::new();
256            for (point_id, sp) in &col.points {
257                let Some(key_val) = sp.payload.get(&key_field).and_then(|v| v.as_str()) else {
258                    continue;
259                };
260                let mut fields = HashMap::new();
261                for (k, v) in &sp.payload {
262                    if let Some(s) = v.as_str() {
263                        fields.insert(k.clone(), s.to_owned());
264                    }
265                }
266                // Ensure the key_field value is always present in the fields map.
267                fields.insert(key_field.clone(), key_val.to_owned());
268                result.push((point_id.clone(), fields));
269            }
270            Ok(result)
271        })
272    }
273
274    fn health_check(&self) -> BoxFuture<'_, Result<bool, VectorStoreError>> {
275        Box::pin(async { Ok(true) })
276    }
277
278    fn get_points(
279        &self,
280        collection: &str,
281        ids: Vec<String>,
282    ) -> BoxFuture<'_, Result<Vec<VectorPoint>, VectorStoreError>> {
283        let collection = collection.to_owned();
284        Box::pin(async move {
285            let cols = self.collections.read();
286            let col = cols.get(&collection).ok_or_else(|| {
287                VectorStoreError::Unsupported(format!("collection {collection} not found"))
288            })?;
289            let points = ids
290                .into_iter()
291                .filter_map(|id| {
292                    col.points.get(&id).map(|sp| VectorPoint {
293                        id: id.clone(),
294                        vector: sp.vector.clone(),
295                        payload: sp.payload.clone(),
296                    })
297                })
298                .collect();
299            Ok(points)
300        })
301    }
302}
303
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a point with a single `key -> value` payload entry.
    fn point(id: &str, vector: Vec<f32>, key: &str, value: serde_json::Value) -> VectorPoint {
        VectorPoint {
            id: id.into(),
            vector,
            payload: HashMap::from([(key.into(), value)]),
        }
    }

    #[tokio::test]
    async fn ensure_collection_and_exists() {
        let store = InMemoryVectorStore::new();
        assert!(!store.collection_exists("test").await.unwrap());
        store.ensure_collection("test", 3).await.unwrap();
        assert!(store.collection_exists("test").await.unwrap());
    }

    #[tokio::test]
    async fn ensure_collection_idempotent() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();
        store.ensure_collection("test", 3).await.unwrap();
        assert!(store.collection_exists("test").await.unwrap());
    }

    #[tokio::test]
    async fn delete_collection_removes() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();
        store.delete_collection("test").await.unwrap();
        assert!(!store.collection_exists("test").await.unwrap());
    }

    #[tokio::test]
    async fn upsert_and_search() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();

        let points = vec![
            point("a", vec![1.0, 0.0, 0.0], "name", serde_json::json!("alpha")),
            point("b", vec![0.0, 1.0, 0.0], "name", serde_json::json!("beta")),
        ];
        store.upsert("test", points).await.unwrap();

        let results = store
            .search("test", vec![1.0, 0.0, 0.0], 2, None)
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "a");
        assert!((results[0].score - 1.0).abs() < f32::EPSILON);
    }

    #[tokio::test]
    async fn search_respects_limit() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();

        let points = vec![
            point("a", vec![1.0, 0.0, 0.0], "name", serde_json::json!("alpha")),
            point("b", vec![0.0, 1.0, 0.0], "name", serde_json::json!("beta")),
            point("c", vec![0.0, 0.0, 1.0], "name", serde_json::json!("gamma")),
        ];
        store.upsert("test", points).await.unwrap();

        let results = store
            .search("test", vec![1.0, 0.0, 0.0], 1, None)
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "a");
    }

    #[tokio::test]
    async fn search_with_filter() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();

        let points = vec![
            point("a", vec![1.0, 0.0, 0.0], "role", serde_json::json!("user")),
            point("b", vec![0.9, 0.1, 0.0], "role", serde_json::json!("assistant")),
        ];
        store.upsert("test", points).await.unwrap();

        let filter = VectorFilter {
            must: vec![crate::vector_store::FieldCondition {
                field: "role".into(),
                value: FieldValue::Text("user".into()),
            }],
            must_not: vec![],
        };
        let results = store
            .search("test", vec![1.0, 0.0, 0.0], 10, Some(filter))
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "a");
    }

    #[tokio::test]
    async fn search_with_must_not_filter_keeps_points_missing_the_field() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();

        let points = vec![
            point("a", vec![1.0, 0.0, 0.0], "role", serde_json::json!("user")),
            point("b", vec![0.9, 0.1, 0.0], "role", serde_json::json!("assistant")),
            VectorPoint {
                id: "c".into(),
                vector: vec![0.0, 0.0, 1.0],
                payload: HashMap::new(),
            },
        ];
        store.upsert("test", points).await.unwrap();

        let filter = VectorFilter {
            must: vec![],
            must_not: vec![crate::vector_store::FieldCondition {
                field: "role".into(),
                value: FieldValue::Text("assistant".into()),
            }],
        };
        let results = store
            .search("test", vec![1.0, 0.0, 0.0], 10, Some(filter))
            .await
            .unwrap();
        let mut ids: Vec<&str> = results.iter().map(|r| r.id.as_str()).collect();
        ids.sort_unstable();
        assert_eq!(ids, ["a", "c"]);
    }

    #[tokio::test]
    async fn search_with_integer_filter() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();

        let points = vec![
            point("a", vec![1.0, 0.0, 0.0], "n", serde_json::json!(7)),
            point("b", vec![0.9, 0.1, 0.0], "n", serde_json::json!(8)),
        ];
        store.upsert("test", points).await.unwrap();

        let filter = VectorFilter {
            must: vec![crate::vector_store::FieldCondition {
                field: "n".into(),
                value: FieldValue::Integer(7),
            }],
            must_not: vec![],
        };
        let results = store
            .search("test", vec![1.0, 0.0, 0.0], 10, Some(filter))
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "a");
    }

    #[tokio::test]
    async fn upsert_overwrites_existing_id() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();

        store
            .upsert(
                "test",
                vec![point("a", vec![1.0, 0.0, 0.0], "name", serde_json::json!("alpha"))],
            )
            .await
            .unwrap();
        store
            .upsert(
                "test",
                vec![point("a", vec![0.0, 1.0, 0.0], "name", serde_json::json!("beta"))],
            )
            .await
            .unwrap();

        let got = store.get_points("test", vec!["a".into()]).await.unwrap();
        assert_eq!(got.len(), 1);
        assert_eq!(got[0].vector, vec![0.0, 1.0, 0.0]);
        assert_eq!(got[0].payload.get("name"), Some(&serde_json::json!("beta")));
    }

    #[tokio::test]
    async fn get_points_skips_unknown_ids() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();
        store
            .upsert(
                "test",
                vec![point("a", vec![1.0, 0.0, 0.0], "name", serde_json::json!("alpha"))],
            )
            .await
            .unwrap();

        let got = store
            .get_points("test", vec!["a".into(), "missing".into()])
            .await
            .unwrap();
        assert_eq!(got.len(), 1);
        assert_eq!(got[0].id, "a");
        assert_eq!(got[0].vector, vec![1.0, 0.0, 0.0]);
    }

    #[tokio::test]
    async fn operations_on_missing_collection_fail() {
        let store = InMemoryVectorStore::new();
        assert!(store.upsert("missing", vec![]).await.is_err());
        assert!(
            store
                .search("missing", vec![1.0, 0.0, 0.0], 10, None)
                .await
                .is_err()
        );
        assert!(
            store
                .delete_by_ids("missing", vec!["a".into()])
                .await
                .is_err()
        );
        assert!(store.scroll_all("missing", "name").await.is_err());
        assert!(
            store
                .scroll_all_with_point_ids("missing", "name")
                .await
                .is_err()
        );
        assert!(
            store
                .get_points("missing", vec!["a".into()])
                .await
                .is_err()
        );
    }

    #[tokio::test]
    async fn delete_by_ids_removes_points() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();

        let points = vec![VectorPoint {
            id: "a".into(),
            vector: vec![1.0, 0.0, 0.0],
            payload: HashMap::new(),
        }];
        store.upsert("test", points).await.unwrap();
        store.delete_by_ids("test", vec!["a".into()]).await.unwrap();

        let results = store
            .search("test", vec![1.0, 0.0, 0.0], 10, None)
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn scroll_all_extracts_strings() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();

        let points = vec![VectorPoint {
            id: "a".into(),
            vector: vec![1.0, 0.0, 0.0],
            payload: HashMap::from([
                ("name".into(), serde_json::json!("alpha")),
                ("desc".into(), serde_json::json!("first")),
                ("num".into(), serde_json::json!(42)),
            ]),
        }];
        store.upsert("test", points).await.unwrap();

        let result = store.scroll_all("test", "name").await.unwrap();
        assert_eq!(result.len(), 1);
        let fields = result.get("alpha").unwrap();
        assert_eq!(fields.get("desc").unwrap(), "first");
        assert!(!fields.contains_key("num"));
    }

    #[tokio::test]
    async fn scroll_all_with_point_ids_returns_point_id() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();

        let points = vec![
            VectorPoint {
                id: "pid-1".into(),
                vector: vec![1.0, 0.0, 0.0],
                payload: HashMap::from([
                    ("entity_id_str".into(), serde_json::json!("42")),
                    ("name".into(), serde_json::json!("Alpha")),
                    ("count".into(), serde_json::json!(7)), // non-string, must be excluded
                ]),
            },
            VectorPoint {
                id: "pid-2".into(),
                vector: vec![0.0, 1.0, 0.0],
                // Missing key_field: must be excluded from results.
                payload: HashMap::from([("name".into(), serde_json::json!("Beta"))]),
            },
        ];
        store.upsert("test", points).await.unwrap();

        let result = store
            .scroll_all_with_point_ids("test", "entity_id_str")
            .await
            .unwrap();

        assert_eq!(
            result.len(),
            1,
            "only the point with key_field should appear"
        );
        let (point_id, fields) = &result[0];
        assert_eq!(point_id, "pid-1");
        assert_eq!(fields.get("entity_id_str").map(String::as_str), Some("42"));
        assert_eq!(fields.get("name").map(String::as_str), Some("Alpha"));
        // Non-string field must be absent.
        assert!(!fields.contains_key("count"));
    }

    #[tokio::test]
    async fn health_check_reports_healthy() {
        let store = InMemoryVectorStore::new();
        assert!(store.health_check().await.unwrap());
    }

    #[test]
    fn cosine_similarity_import_wired() {
        // Smoke test: verifies the re-export binding is intact. Edge-case coverage is in math.rs.
        assert!(!cosine_similarity(&[1.0, 0.0, 0.0], &[0.0, 1.0, 0.0]).is_nan());
    }

    #[tokio::test]
    async fn default_impl() {
        let store = InMemoryVectorStore::default();
        assert!(!store.collection_exists("any").await.unwrap());
    }

    #[test]
    fn debug_format() {
        let store = InMemoryVectorStore::new();
        let dbg = format!("{store:?}");
        assert!(dbg.contains("InMemoryVectorStore"));
    }
}