Skip to main content

zeph_memory/
in_memory_store.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Purely in-memory [`VectorStore`] implementation for unit tests.
5//!
6//! All data lives in a `RwLock<HashMap>` and is discarded when the store is dropped.
7//! Cosine similarity is computed with [`zeph_common::math::cosine_similarity`].
8//! Not suitable for production — use [`crate::qdrant_ops::QdrantOps`] instead.
9
10use std::collections::HashMap;
11
12use parking_lot::RwLock;
13
14use crate::vector_store::{
15    BoxFuture, FieldValue, ScoredVectorPoint, VectorFilter, VectorPoint, VectorStore,
16    VectorStoreError,
17};
18
19struct StoredPoint {
20    vector: Vec<f32>,
21    payload: HashMap<String, serde_json::Value>,
22}
23
24struct InMemoryCollection {
25    points: HashMap<String, StoredPoint>,
26}
27
28/// In-process vector store backed by a `RwLock<HashMap>`.
29///
30/// Intended for unit tests only.  All data is lost when the store is dropped.
31///
32/// # Examples
33///
34/// ```
35/// use zeph_memory::in_memory_store::InMemoryVectorStore;
36///
37/// let store = InMemoryVectorStore::new();
38/// // Use `store` as a `VectorStore` in tests.
39/// ```
40pub struct InMemoryVectorStore {
41    collections: RwLock<HashMap<String, InMemoryCollection>>,
42}
43
44impl InMemoryVectorStore {
45    #[must_use]
46    pub fn new() -> Self {
47        Self {
48            collections: RwLock::new(HashMap::new()),
49        }
50    }
51}
52
53impl Default for InMemoryVectorStore {
54    fn default() -> Self {
55        Self::new()
56    }
57}
58
59impl std::fmt::Debug for InMemoryVectorStore {
60    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61        f.debug_struct("InMemoryVectorStore")
62            .finish_non_exhaustive()
63    }
64}
65
66use zeph_common::math::cosine_similarity;
67
68fn matches_filter(payload: &HashMap<String, serde_json::Value>, filter: &VectorFilter) -> bool {
69    for cond in &filter.must {
70        let Some(val) = payload.get(&cond.field) else {
71            return false;
72        };
73        if !field_matches(val, &cond.value) {
74            return false;
75        }
76    }
77    for cond in &filter.must_not {
78        if let Some(val) = payload.get(&cond.field)
79            && field_matches(val, &cond.value)
80        {
81            return false;
82        }
83    }
84    true
85}
86
87fn field_matches(val: &serde_json::Value, expected: &FieldValue) -> bool {
88    match expected {
89        FieldValue::Integer(i) => val.as_i64() == Some(*i),
90        FieldValue::Text(s) => val.as_str() == Some(s.as_str()),
91    }
92}
93
94impl VectorStore for InMemoryVectorStore {
95    fn ensure_collection(
96        &self,
97        collection: &str,
98        _vector_size: u64,
99    ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
100        let collection = collection.to_owned();
101        Box::pin(async move {
102            let mut cols = self.collections.write();
103            cols.entry(collection)
104                .or_insert_with(|| InMemoryCollection {
105                    points: HashMap::new(),
106                });
107            Ok(())
108        })
109    }
110
111    fn collection_exists(&self, collection: &str) -> BoxFuture<'_, Result<bool, VectorStoreError>> {
112        let collection = collection.to_owned();
113        Box::pin(async move {
114            let cols = self.collections.read();
115            Ok(cols.contains_key(&collection))
116        })
117    }
118
119    fn delete_collection(&self, collection: &str) -> BoxFuture<'_, Result<(), VectorStoreError>> {
120        let collection = collection.to_owned();
121        Box::pin(async move {
122            let mut cols = self.collections.write();
123            cols.remove(&collection);
124            Ok(())
125        })
126    }
127
128    fn upsert(
129        &self,
130        collection: &str,
131        points: Vec<VectorPoint>,
132    ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
133        let collection = collection.to_owned();
134        Box::pin(async move {
135            let mut cols = self.collections.write();
136            let col = cols.get_mut(&collection).ok_or_else(|| {
137                VectorStoreError::Upsert(format!("collection {collection} not found"))
138            })?;
139            for p in points {
140                col.points.insert(
141                    p.id,
142                    StoredPoint {
143                        vector: p.vector,
144                        payload: p.payload,
145                    },
146                );
147            }
148            Ok(())
149        })
150    }
151
152    fn search(
153        &self,
154        collection: &str,
155        vector: Vec<f32>,
156        limit: u64,
157        filter: Option<VectorFilter>,
158    ) -> BoxFuture<'_, Result<Vec<ScoredVectorPoint>, VectorStoreError>> {
159        let collection = collection.to_owned();
160        Box::pin(async move {
161            let cols = self.collections.read();
162            let col = cols.get(&collection).ok_or_else(|| {
163                VectorStoreError::Search(format!("collection {collection} not found"))
164            })?;
165
166            let empty_filter = VectorFilter::default();
167            let f = filter.as_ref().unwrap_or(&empty_filter);
168
169            let mut scored: Vec<ScoredVectorPoint> = col
170                .points
171                .iter()
172                .filter(|(_, sp)| matches_filter(&sp.payload, f))
173                .map(|(id, sp)| ScoredVectorPoint {
174                    id: id.clone(),
175                    score: cosine_similarity(&vector, &sp.vector),
176                    payload: sp.payload.clone(),
177                })
178                .collect();
179
180            scored.sort_by(|a, b| {
181                b.score
182                    .partial_cmp(&a.score)
183                    .unwrap_or(std::cmp::Ordering::Equal)
184            });
185            #[expect(clippy::cast_possible_truncation)]
186            scored.truncate(limit as usize);
187            Ok(scored)
188        })
189    }
190
191    fn delete_by_ids(
192        &self,
193        collection: &str,
194        ids: Vec<String>,
195    ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
196        let collection = collection.to_owned();
197        Box::pin(async move {
198            if ids.is_empty() {
199                return Ok(());
200            }
201            let mut cols = self.collections.write();
202            let col = cols.get_mut(&collection).ok_or_else(|| {
203                VectorStoreError::Delete(format!("collection {collection} not found"))
204            })?;
205            for id in &ids {
206                col.points.remove(id);
207            }
208            Ok(())
209        })
210    }
211
212    fn scroll_all(
213        &self,
214        collection: &str,
215        key_field: &str,
216    ) -> BoxFuture<'_, Result<HashMap<String, HashMap<String, String>>, VectorStoreError>> {
217        let collection = collection.to_owned();
218        let key_field = key_field.to_owned();
219        Box::pin(async move {
220            let cols = self.collections.read();
221            let col = cols.get(&collection).ok_or_else(|| {
222                VectorStoreError::Scroll(format!("collection {collection} not found"))
223            })?;
224
225            let mut result = HashMap::new();
226            for sp in col.points.values() {
227                let Some(key_val) = sp.payload.get(&key_field).and_then(|v| v.as_str()) else {
228                    continue;
229                };
230                let mut fields = HashMap::new();
231                for (k, v) in &sp.payload {
232                    if let Some(s) = v.as_str() {
233                        fields.insert(k.clone(), s.to_owned());
234                    }
235                }
236                result.insert(key_val.to_owned(), fields);
237            }
238            Ok(result)
239        })
240    }
241
242    fn health_check(&self) -> BoxFuture<'_, Result<bool, VectorStoreError>> {
243        Box::pin(async { Ok(true) })
244    }
245}
246
#[cfg(test)]
mod tests {
    use super::*;

    /// Name of the collection shared by the tests below.
    const COLLECTION: &str = "test";

    /// Builds a store that already contains the empty test collection.
    async fn store_with_collection() -> InMemoryVectorStore {
        let store = InMemoryVectorStore::new();
        store.ensure_collection(COLLECTION, 3).await.unwrap();
        store
    }

    /// Builds a three-dimensional point from an id, a vector and payload entries.
    fn point<const N: usize>(
        id: &str,
        vector: [f32; 3],
        payload: [(&str, serde_json::Value); N],
    ) -> VectorPoint {
        VectorPoint {
            id: id.into(),
            vector: vector.to_vec(),
            payload: payload
                .into_iter()
                .map(|(key, value)| (key.to_owned(), value))
                .collect(),
        }
    }

    #[tokio::test]
    async fn ensure_collection_and_exists() {
        let store = InMemoryVectorStore::new();
        assert!(!store.collection_exists(COLLECTION).await.unwrap());
        store.ensure_collection(COLLECTION, 3).await.unwrap();
        assert!(store.collection_exists(COLLECTION).await.unwrap());
    }

    #[tokio::test]
    async fn ensure_collection_idempotent() {
        let store = store_with_collection().await;
        store.ensure_collection(COLLECTION, 3).await.unwrap();
        assert!(store.collection_exists(COLLECTION).await.unwrap());
    }

    #[tokio::test]
    async fn delete_collection_removes() {
        let store = store_with_collection().await;
        store.delete_collection(COLLECTION).await.unwrap();
        assert!(!store.collection_exists(COLLECTION).await.unwrap());
    }

    #[tokio::test]
    async fn upsert_and_search() {
        let store = store_with_collection().await;
        let points = vec![
            point("a", [1.0, 0.0, 0.0], [("name", serde_json::json!("alpha"))]),
            point("b", [0.0, 1.0, 0.0], [("name", serde_json::json!("beta"))]),
        ];
        store.upsert(COLLECTION, points).await.unwrap();

        let results = store
            .search(COLLECTION, vec![1.0, 0.0, 0.0], 2, None)
            .await
            .unwrap();

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "a");
        assert!((results[0].score - 1.0).abs() < f32::EPSILON);
    }

    #[tokio::test]
    async fn search_with_filter() {
        let store = store_with_collection().await;
        let points = vec![
            point("a", [1.0, 0.0, 0.0], [("role", serde_json::json!("user"))]),
            point(
                "b",
                [0.9, 0.1, 0.0],
                [("role", serde_json::json!("assistant"))],
            ),
        ];
        store.upsert(COLLECTION, points).await.unwrap();

        let only_users = VectorFilter {
            must: vec![crate::vector_store::FieldCondition {
                field: "role".into(),
                value: FieldValue::Text("user".into()),
            }],
            must_not: Vec::new(),
        };
        let results = store
            .search(COLLECTION, vec![1.0, 0.0, 0.0], 10, Some(only_users))
            .await
            .unwrap();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "a");
    }

    #[tokio::test]
    async fn delete_by_ids_removes_points() {
        let store = store_with_collection().await;
        store
            .upsert(COLLECTION, vec![point("a", [1.0, 0.0, 0.0], [])])
            .await
            .unwrap();
        store
            .delete_by_ids(COLLECTION, vec!["a".into()])
            .await
            .unwrap();

        let results = store
            .search(COLLECTION, vec![1.0, 0.0, 0.0], 10, None)
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn scroll_all_extracts_strings() {
        let store = store_with_collection().await;
        let payload = [
            ("name", serde_json::json!("alpha")),
            ("desc", serde_json::json!("first")),
            ("num", serde_json::json!(42)),
        ];
        store
            .upsert(COLLECTION, vec![point("a", [1.0, 0.0, 0.0], payload)])
            .await
            .unwrap();

        let result = store.scroll_all(COLLECTION, "name").await.unwrap();
        assert_eq!(result.len(), 1);

        // Only string-valued fields are carried over; the numeric one is dropped.
        let fields = result.get("alpha").unwrap();
        assert_eq!(fields.get("desc").unwrap(), "first");
        assert!(!fields.contains_key("num"));
    }

    #[test]
    fn cosine_similarity_import_wired() {
        // Smoke test: verifies the import is intact. Edge-case coverage is in math.rs.
        let score = cosine_similarity(&[1.0, 0.0, 0.0], &[0.0, 1.0, 0.0]);
        assert!(!score.is_nan());
    }

    #[tokio::test]
    async fn default_impl() {
        let store = InMemoryVectorStore::default();
        assert!(!store.collection_exists("any").await.unwrap());
    }

    #[test]
    fn debug_format() {
        let rendered = format!("{:?}", InMemoryVectorStore::new());
        assert!(rendered.contains("InMemoryVectorStore"));
    }
}