Skip to main content

zeph_memory/
in_memory_store.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::collections::HashMap;
5use std::future::Future;
6use std::pin::Pin;
7use std::sync::RwLock;
8
9use crate::vector_store::{
10    FieldValue, ScoredVectorPoint, VectorFilter, VectorPoint, VectorStore, VectorStoreError,
11};
12
/// Heap-allocated, `Send` future borrowing for `'a`; the return type of every
/// [`VectorStore`] method implemented in this file.
type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
14
15struct StoredPoint {
16    vector: Vec<f32>,
17    payload: HashMap<String, serde_json::Value>,
18}
19
20struct InMemoryCollection {
21    points: HashMap<String, StoredPoint>,
22}
23
24pub struct InMemoryVectorStore {
25    collections: RwLock<HashMap<String, InMemoryCollection>>,
26}
27
28impl InMemoryVectorStore {
29    #[must_use]
30    pub fn new() -> Self {
31        Self {
32            collections: RwLock::new(HashMap::new()),
33        }
34    }
35}
36
37impl Default for InMemoryVectorStore {
38    fn default() -> Self {
39        Self::new()
40    }
41}
42
43impl std::fmt::Debug for InMemoryVectorStore {
44    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45        f.debug_struct("InMemoryVectorStore")
46            .finish_non_exhaustive()
47    }
48}
49
50use crate::math::cosine_similarity;
51
52fn matches_filter(payload: &HashMap<String, serde_json::Value>, filter: &VectorFilter) -> bool {
53    for cond in &filter.must {
54        let Some(val) = payload.get(&cond.field) else {
55            return false;
56        };
57        if !field_matches(val, &cond.value) {
58            return false;
59        }
60    }
61    for cond in &filter.must_not {
62        if let Some(val) = payload.get(&cond.field)
63            && field_matches(val, &cond.value)
64        {
65            return false;
66        }
67    }
68    true
69}
70
71fn field_matches(val: &serde_json::Value, expected: &FieldValue) -> bool {
72    match expected {
73        FieldValue::Integer(i) => val.as_i64() == Some(*i),
74        FieldValue::Text(s) => val.as_str() == Some(s.as_str()),
75    }
76}
77
78impl VectorStore for InMemoryVectorStore {
79    fn ensure_collection(
80        &self,
81        collection: &str,
82        _vector_size: u64,
83    ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
84        let collection = collection.to_owned();
85        Box::pin(async move {
86            let mut cols = self
87                .collections
88                .write()
89                .map_err(|e| VectorStoreError::Collection(e.to_string()))?;
90            cols.entry(collection)
91                .or_insert_with(|| InMemoryCollection {
92                    points: HashMap::new(),
93                });
94            Ok(())
95        })
96    }
97
98    fn collection_exists(&self, collection: &str) -> BoxFuture<'_, Result<bool, VectorStoreError>> {
99        let collection = collection.to_owned();
100        Box::pin(async move {
101            let cols = self
102                .collections
103                .read()
104                .map_err(|e| VectorStoreError::Collection(e.to_string()))?;
105            Ok(cols.contains_key(&collection))
106        })
107    }
108
109    fn delete_collection(&self, collection: &str) -> BoxFuture<'_, Result<(), VectorStoreError>> {
110        let collection = collection.to_owned();
111        Box::pin(async move {
112            let mut cols = self
113                .collections
114                .write()
115                .map_err(|e| VectorStoreError::Collection(e.to_string()))?;
116            cols.remove(&collection);
117            Ok(())
118        })
119    }
120
121    fn upsert(
122        &self,
123        collection: &str,
124        points: Vec<VectorPoint>,
125    ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
126        let collection = collection.to_owned();
127        Box::pin(async move {
128            let mut cols = self
129                .collections
130                .write()
131                .map_err(|e| VectorStoreError::Upsert(e.to_string()))?;
132            let col = cols.get_mut(&collection).ok_or_else(|| {
133                VectorStoreError::Upsert(format!("collection {collection} not found"))
134            })?;
135            for p in points {
136                col.points.insert(
137                    p.id,
138                    StoredPoint {
139                        vector: p.vector,
140                        payload: p.payload,
141                    },
142                );
143            }
144            Ok(())
145        })
146    }
147
148    fn search(
149        &self,
150        collection: &str,
151        vector: Vec<f32>,
152        limit: u64,
153        filter: Option<VectorFilter>,
154    ) -> BoxFuture<'_, Result<Vec<ScoredVectorPoint>, VectorStoreError>> {
155        let collection = collection.to_owned();
156        Box::pin(async move {
157            let cols = self
158                .collections
159                .read()
160                .map_err(|e| VectorStoreError::Search(e.to_string()))?;
161            let col = cols.get(&collection).ok_or_else(|| {
162                VectorStoreError::Search(format!("collection {collection} not found"))
163            })?;
164
165            let empty_filter = VectorFilter::default();
166            let f = filter.as_ref().unwrap_or(&empty_filter);
167
168            let mut scored: Vec<ScoredVectorPoint> = col
169                .points
170                .iter()
171                .filter(|(_, sp)| matches_filter(&sp.payload, f))
172                .map(|(id, sp)| ScoredVectorPoint {
173                    id: id.clone(),
174                    score: cosine_similarity(&vector, &sp.vector),
175                    payload: sp.payload.clone(),
176                })
177                .collect();
178
179            scored.sort_by(|a, b| {
180                b.score
181                    .partial_cmp(&a.score)
182                    .unwrap_or(std::cmp::Ordering::Equal)
183            });
184            #[expect(clippy::cast_possible_truncation)]
185            scored.truncate(limit as usize);
186            Ok(scored)
187        })
188    }
189
190    fn delete_by_ids(
191        &self,
192        collection: &str,
193        ids: Vec<String>,
194    ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
195        let collection = collection.to_owned();
196        Box::pin(async move {
197            if ids.is_empty() {
198                return Ok(());
199            }
200            let mut cols = self
201                .collections
202                .write()
203                .map_err(|e| VectorStoreError::Delete(e.to_string()))?;
204            let col = cols.get_mut(&collection).ok_or_else(|| {
205                VectorStoreError::Delete(format!("collection {collection} not found"))
206            })?;
207            for id in &ids {
208                col.points.remove(id);
209            }
210            Ok(())
211        })
212    }
213
214    fn scroll_all(
215        &self,
216        collection: &str,
217        key_field: &str,
218    ) -> BoxFuture<'_, Result<HashMap<String, HashMap<String, String>>, VectorStoreError>> {
219        let collection = collection.to_owned();
220        let key_field = key_field.to_owned();
221        Box::pin(async move {
222            let cols = self
223                .collections
224                .read()
225                .map_err(|e| VectorStoreError::Scroll(e.to_string()))?;
226            let col = cols.get(&collection).ok_or_else(|| {
227                VectorStoreError::Scroll(format!("collection {collection} not found"))
228            })?;
229
230            let mut result = HashMap::new();
231            for sp in col.points.values() {
232                let Some(key_val) = sp.payload.get(&key_field).and_then(|v| v.as_str()) else {
233                    continue;
234                };
235                let mut fields = HashMap::new();
236                for (k, v) in &sp.payload {
237                    if let Some(s) = v.as_str() {
238                        fields.insert(k.clone(), s.to_owned());
239                    }
240                }
241                result.insert(key_val.to_owned(), fields);
242            }
243            Ok(result)
244        })
245    }
246
247    fn health_check(&self) -> BoxFuture<'_, Result<bool, VectorStoreError>> {
248        Box::pin(async { Ok(true) })
249    }
250}
251
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a [`VectorPoint`] from its parts (used by the newer tests below).
    fn point(
        id: &str,
        vector: Vec<f32>,
        payload: HashMap<String, serde_json::Value>,
    ) -> VectorPoint {
        VectorPoint {
            id: id.into(),
            vector,
            payload,
        }
    }

    #[tokio::test]
    async fn ensure_collection_and_exists() {
        let store = InMemoryVectorStore::new();
        assert!(!store.collection_exists("test").await.unwrap());
        store.ensure_collection("test", 3).await.unwrap();
        assert!(store.collection_exists("test").await.unwrap());
    }

    #[tokio::test]
    async fn ensure_collection_idempotent() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();
        store.ensure_collection("test", 3).await.unwrap();
        assert!(store.collection_exists("test").await.unwrap());
    }

    #[tokio::test]
    async fn delete_collection_removes() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();
        store.delete_collection("test").await.unwrap();
        assert!(!store.collection_exists("test").await.unwrap());
    }

    #[tokio::test]
    async fn delete_collection_missing_is_ok() {
        let store = InMemoryVectorStore::new();
        assert!(store.delete_collection("missing").await.is_ok());
    }

    #[tokio::test]
    async fn upsert_and_search() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();

        let points = vec![
            VectorPoint {
                id: "a".into(),
                vector: vec![1.0, 0.0, 0.0],
                payload: HashMap::from([("name".into(), serde_json::json!("alpha"))]),
            },
            VectorPoint {
                id: "b".into(),
                vector: vec![0.0, 1.0, 0.0],
                payload: HashMap::from([("name".into(), serde_json::json!("beta"))]),
            },
        ];
        store.upsert("test", points).await.unwrap();

        let results = store
            .search("test", vec![1.0, 0.0, 0.0], 2, None)
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "a");
        assert!((results[0].score - 1.0).abs() < f32::EPSILON);
    }

    #[tokio::test]
    async fn upsert_missing_collection_errors() {
        let store = InMemoryVectorStore::new();
        let points = vec![point("a", vec![1.0, 0.0, 0.0], HashMap::new())];
        assert!(store.upsert("missing", points).await.is_err());
    }

    #[tokio::test]
    async fn upsert_overwrites_existing_id() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();
        for name in ["alpha", "omega"] {
            let payload: HashMap<String, serde_json::Value> =
                HashMap::from([("name".into(), serde_json::json!(name))]);
            store
                .upsert("test", vec![point("a", vec![1.0, 0.0, 0.0], payload)])
                .await
                .unwrap();
        }

        let result = store.scroll_all("test", "name").await.unwrap();
        assert_eq!(result.len(), 1);
        assert!(result.contains_key("omega"));
    }

    #[tokio::test]
    async fn search_missing_collection_errors() {
        let store = InMemoryVectorStore::new();
        let result = store
            .search("missing", vec![1.0, 0.0, 0.0], 10, None)
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn search_orders_by_score_and_respects_limit() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();
        let points = vec![
            point("far", vec![0.0, 0.0, 1.0], HashMap::new()),
            point("near", vec![1.0, 0.1, 0.0], HashMap::new()),
            point("exact", vec![1.0, 0.0, 0.0], HashMap::new()),
        ];
        store.upsert("test", points).await.unwrap();

        let results = store
            .search("test", vec![1.0, 0.0, 0.0], 2, None)
            .await
            .unwrap();
        let ids: Vec<&str> = results.iter().map(|r| r.id.as_str()).collect();
        assert_eq!(ids, ["exact", "near"]);
    }

    #[tokio::test]
    async fn search_with_filter() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();

        let points = vec![
            VectorPoint {
                id: "a".into(),
                vector: vec![1.0, 0.0, 0.0],
                payload: HashMap::from([("role".into(), serde_json::json!("user"))]),
            },
            VectorPoint {
                id: "b".into(),
                vector: vec![0.9, 0.1, 0.0],
                payload: HashMap::from([("role".into(), serde_json::json!("assistant"))]),
            },
        ];
        store.upsert("test", points).await.unwrap();

        let filter = VectorFilter {
            must: vec![crate::vector_store::FieldCondition {
                field: "role".into(),
                value: FieldValue::Text("user".into()),
            }],
            must_not: vec![],
        };
        let results = store
            .search("test", vec![1.0, 0.0, 0.0], 10, Some(filter))
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "a");
    }

    #[tokio::test]
    async fn search_with_must_not_filter_keeps_points_missing_the_field() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();
        let points = vec![
            point(
                "a",
                vec![1.0, 0.0, 0.0],
                HashMap::from([("role".into(), serde_json::json!("user"))]),
            ),
            point(
                "b",
                vec![0.9, 0.1, 0.0],
                HashMap::from([("role".into(), serde_json::json!("assistant"))]),
            ),
            point("c", vec![0.0, 1.0, 0.0], HashMap::new()),
        ];
        store.upsert("test", points).await.unwrap();

        let filter = VectorFilter {
            must: vec![],
            must_not: vec![crate::vector_store::FieldCondition {
                field: "role".into(),
                value: FieldValue::Text("assistant".into()),
            }],
        };
        let results = store
            .search("test", vec![1.0, 0.0, 0.0], 10, Some(filter))
            .await
            .unwrap();
        let ids: Vec<&str> = results.iter().map(|r| r.id.as_str()).collect();
        assert_eq!(ids, ["a", "c"]);
    }

    #[tokio::test]
    async fn search_with_integer_filter_requires_field() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();
        let points = vec![
            point(
                "one",
                vec![1.0, 0.0, 0.0],
                HashMap::from([("n".into(), serde_json::json!(1))]),
            ),
            point(
                "two",
                vec![1.0, 0.0, 0.0],
                HashMap::from([("n".into(), serde_json::json!(2))]),
            ),
            point("none", vec![1.0, 0.0, 0.0], HashMap::new()),
        ];
        store.upsert("test", points).await.unwrap();

        let filter = VectorFilter {
            must: vec![crate::vector_store::FieldCondition {
                field: "n".into(),
                value: FieldValue::Integer(2),
            }],
            must_not: vec![],
        };
        let results = store
            .search("test", vec![1.0, 0.0, 0.0], 10, Some(filter))
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "two");
    }

    #[tokio::test]
    async fn delete_by_ids_removes_points() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();

        let points = vec![VectorPoint {
            id: "a".into(),
            vector: vec![1.0, 0.0, 0.0],
            payload: HashMap::new(),
        }];
        store.upsert("test", points).await.unwrap();
        store.delete_by_ids("test", vec!["a".into()]).await.unwrap();

        let results = store
            .search("test", vec![1.0, 0.0, 0.0], 10, None)
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn delete_by_ids_empty_is_noop_even_without_collection() {
        let store = InMemoryVectorStore::new();
        assert!(store.delete_by_ids("missing", vec![]).await.is_ok());
        assert!(
            store
                .delete_by_ids("missing", vec!["a".into()])
                .await
                .is_err()
        );
    }

    #[tokio::test]
    async fn scroll_all_extracts_strings() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();

        let points = vec![VectorPoint {
            id: "a".into(),
            vector: vec![1.0, 0.0, 0.0],
            payload: HashMap::from([
                ("name".into(), serde_json::json!("alpha")),
                ("desc".into(), serde_json::json!("first")),
                ("num".into(), serde_json::json!(42)),
            ]),
        }];
        store.upsert("test", points).await.unwrap();

        let result = store.scroll_all("test", "name").await.unwrap();
        assert_eq!(result.len(), 1);
        let fields = result.get("alpha").unwrap();
        assert_eq!(fields.get("desc").unwrap(), "first");
        assert!(!fields.contains_key("num"));
    }

    #[tokio::test]
    async fn scroll_all_skips_points_without_string_key() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();
        let points = vec![
            point(
                "a",
                vec![1.0, 0.0, 0.0],
                HashMap::from([("name".into(), serde_json::json!("alpha"))]),
            ),
            point(
                "b",
                vec![0.0, 1.0, 0.0],
                HashMap::from([("name".into(), serde_json::json!(7))]),
            ),
            point("c", vec![0.0, 0.0, 1.0], HashMap::new()),
        ];
        store.upsert("test", points).await.unwrap();

        let result = store.scroll_all("test", "name").await.unwrap();
        assert_eq!(result.len(), 1);
        assert!(result.contains_key("alpha"));
    }

    #[tokio::test]
    async fn scroll_all_missing_collection_errors() {
        let store = InMemoryVectorStore::new();
        assert!(store.scroll_all("missing", "name").await.is_err());
    }

    #[test]
    fn cosine_similarity_import_wired() {
        // Smoke test: verifies the re-export binding is intact. Edge-case coverage is in math.rs.
        assert!(!cosine_similarity(&[1.0, 0.0, 0.0], &[0.0, 1.0, 0.0]).is_nan());
    }

    #[tokio::test]
    async fn health_check_reports_healthy() {
        let store = InMemoryVectorStore::new();
        assert!(store.health_check().await.unwrap());
    }

    #[tokio::test]
    async fn default_impl() {
        let store = InMemoryVectorStore::default();
        assert!(!store.collection_exists("any").await.unwrap());
    }

    #[test]
    fn debug_format() {
        let store = InMemoryVectorStore::new();
        let dbg = format!("{store:?}");
        assert!(dbg.contains("InMemoryVectorStore"));
    }
}