Skip to main content

zeph_memory/
in_memory_store.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::collections::HashMap;
5use std::sync::RwLock;
6
7use crate::vector_store::{
8    BoxFuture, FieldValue, ScoredVectorPoint, VectorFilter, VectorPoint, VectorStore,
9    VectorStoreError,
10};
11
12struct StoredPoint {
13    vector: Vec<f32>,
14    payload: HashMap<String, serde_json::Value>,
15}
16
17struct InMemoryCollection {
18    points: HashMap<String, StoredPoint>,
19}
20
21pub struct InMemoryVectorStore {
22    collections: RwLock<HashMap<String, InMemoryCollection>>,
23}
24
25impl InMemoryVectorStore {
26    #[must_use]
27    pub fn new() -> Self {
28        Self {
29            collections: RwLock::new(HashMap::new()),
30        }
31    }
32}
33
34impl Default for InMemoryVectorStore {
35    fn default() -> Self {
36        Self::new()
37    }
38}
39
40impl std::fmt::Debug for InMemoryVectorStore {
41    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42        f.debug_struct("InMemoryVectorStore")
43            .finish_non_exhaustive()
44    }
45}
46
47use crate::math::cosine_similarity;
48
49fn matches_filter(payload: &HashMap<String, serde_json::Value>, filter: &VectorFilter) -> bool {
50    for cond in &filter.must {
51        let Some(val) = payload.get(&cond.field) else {
52            return false;
53        };
54        if !field_matches(val, &cond.value) {
55            return false;
56        }
57    }
58    for cond in &filter.must_not {
59        if let Some(val) = payload.get(&cond.field)
60            && field_matches(val, &cond.value)
61        {
62            return false;
63        }
64    }
65    true
66}
67
68fn field_matches(val: &serde_json::Value, expected: &FieldValue) -> bool {
69    match expected {
70        FieldValue::Integer(i) => val.as_i64() == Some(*i),
71        FieldValue::Text(s) => val.as_str() == Some(s.as_str()),
72    }
73}
74
75impl VectorStore for InMemoryVectorStore {
76    fn ensure_collection(
77        &self,
78        collection: &str,
79        _vector_size: u64,
80    ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
81        let collection = collection.to_owned();
82        Box::pin(async move {
83            let mut cols = self
84                .collections
85                .write()
86                .map_err(|e| VectorStoreError::Collection(e.to_string()))?;
87            cols.entry(collection)
88                .or_insert_with(|| InMemoryCollection {
89                    points: HashMap::new(),
90                });
91            Ok(())
92        })
93    }
94
95    fn collection_exists(&self, collection: &str) -> BoxFuture<'_, Result<bool, VectorStoreError>> {
96        let collection = collection.to_owned();
97        Box::pin(async move {
98            let cols = self
99                .collections
100                .read()
101                .map_err(|e| VectorStoreError::Collection(e.to_string()))?;
102            Ok(cols.contains_key(&collection))
103        })
104    }
105
106    fn delete_collection(&self, collection: &str) -> BoxFuture<'_, Result<(), VectorStoreError>> {
107        let collection = collection.to_owned();
108        Box::pin(async move {
109            let mut cols = self
110                .collections
111                .write()
112                .map_err(|e| VectorStoreError::Collection(e.to_string()))?;
113            cols.remove(&collection);
114            Ok(())
115        })
116    }
117
118    fn upsert(
119        &self,
120        collection: &str,
121        points: Vec<VectorPoint>,
122    ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
123        let collection = collection.to_owned();
124        Box::pin(async move {
125            let mut cols = self
126                .collections
127                .write()
128                .map_err(|e| VectorStoreError::Upsert(e.to_string()))?;
129            let col = cols.get_mut(&collection).ok_or_else(|| {
130                VectorStoreError::Upsert(format!("collection {collection} not found"))
131            })?;
132            for p in points {
133                col.points.insert(
134                    p.id,
135                    StoredPoint {
136                        vector: p.vector,
137                        payload: p.payload,
138                    },
139                );
140            }
141            Ok(())
142        })
143    }
144
145    fn search(
146        &self,
147        collection: &str,
148        vector: Vec<f32>,
149        limit: u64,
150        filter: Option<VectorFilter>,
151    ) -> BoxFuture<'_, Result<Vec<ScoredVectorPoint>, VectorStoreError>> {
152        let collection = collection.to_owned();
153        Box::pin(async move {
154            let cols = self
155                .collections
156                .read()
157                .map_err(|e| VectorStoreError::Search(e.to_string()))?;
158            let col = cols.get(&collection).ok_or_else(|| {
159                VectorStoreError::Search(format!("collection {collection} not found"))
160            })?;
161
162            let empty_filter = VectorFilter::default();
163            let f = filter.as_ref().unwrap_or(&empty_filter);
164
165            let mut scored: Vec<ScoredVectorPoint> = col
166                .points
167                .iter()
168                .filter(|(_, sp)| matches_filter(&sp.payload, f))
169                .map(|(id, sp)| ScoredVectorPoint {
170                    id: id.clone(),
171                    score: cosine_similarity(&vector, &sp.vector),
172                    payload: sp.payload.clone(),
173                })
174                .collect();
175
176            scored.sort_by(|a, b| {
177                b.score
178                    .partial_cmp(&a.score)
179                    .unwrap_or(std::cmp::Ordering::Equal)
180            });
181            #[expect(clippy::cast_possible_truncation)]
182            scored.truncate(limit as usize);
183            Ok(scored)
184        })
185    }
186
187    fn delete_by_ids(
188        &self,
189        collection: &str,
190        ids: Vec<String>,
191    ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
192        let collection = collection.to_owned();
193        Box::pin(async move {
194            if ids.is_empty() {
195                return Ok(());
196            }
197            let mut cols = self
198                .collections
199                .write()
200                .map_err(|e| VectorStoreError::Delete(e.to_string()))?;
201            let col = cols.get_mut(&collection).ok_or_else(|| {
202                VectorStoreError::Delete(format!("collection {collection} not found"))
203            })?;
204            for id in &ids {
205                col.points.remove(id);
206            }
207            Ok(())
208        })
209    }
210
211    fn scroll_all(
212        &self,
213        collection: &str,
214        key_field: &str,
215    ) -> BoxFuture<'_, Result<HashMap<String, HashMap<String, String>>, VectorStoreError>> {
216        let collection = collection.to_owned();
217        let key_field = key_field.to_owned();
218        Box::pin(async move {
219            let cols = self
220                .collections
221                .read()
222                .map_err(|e| VectorStoreError::Scroll(e.to_string()))?;
223            let col = cols.get(&collection).ok_or_else(|| {
224                VectorStoreError::Scroll(format!("collection {collection} not found"))
225            })?;
226
227            let mut result = HashMap::new();
228            for sp in col.points.values() {
229                let Some(key_val) = sp.payload.get(&key_field).and_then(|v| v.as_str()) else {
230                    continue;
231                };
232                let mut fields = HashMap::new();
233                for (k, v) in &sp.payload {
234                    if let Some(s) = v.as_str() {
235                        fields.insert(k.clone(), s.to_owned());
236                    }
237                }
238                result.insert(key_val.to_owned(), fields);
239            }
240            Ok(result)
241        })
242    }
243
244    fn health_check(&self) -> BoxFuture<'_, Result<bool, VectorStoreError>> {
245        Box::pin(async { Ok(true) })
246    }
247}
248
#[cfg(test)]
mod tests {
    use super::*;

    /// Collection name shared by every test in this module.
    const COLLECTION: &str = "test";

    /// Returns a fresh store in which [`COLLECTION`] has already been created.
    async fn store_with_collection() -> InMemoryVectorStore {
        let store = InMemoryVectorStore::new();
        store.ensure_collection(COLLECTION, 3).await.unwrap();
        store
    }

    /// Builds a [`VectorPoint`] from an id, a three-component vector and payload entries.
    fn point<const N: usize>(
        id: &str,
        vector: [f32; 3],
        payload: [(&str, serde_json::Value); N],
    ) -> VectorPoint {
        VectorPoint {
            id: id.into(),
            vector: vector.to_vec(),
            payload: payload
                .into_iter()
                .map(|(key, value)| (key.into(), value))
                .collect(),
        }
    }

    #[tokio::test]
    async fn ensure_collection_and_exists() {
        let store = InMemoryVectorStore::new();
        let before = store.collection_exists(COLLECTION).await.unwrap();
        assert!(!before);

        store.ensure_collection(COLLECTION, 3).await.unwrap();
        let after = store.collection_exists(COLLECTION).await.unwrap();
        assert!(after);
    }

    #[tokio::test]
    async fn ensure_collection_idempotent() {
        let store = store_with_collection().await;
        // Creating the same collection a second time must succeed as well.
        store.ensure_collection(COLLECTION, 3).await.unwrap();
        assert!(store.collection_exists(COLLECTION).await.unwrap());
    }

    #[tokio::test]
    async fn delete_collection_removes() {
        let store = store_with_collection().await;
        store.delete_collection(COLLECTION).await.unwrap();
        assert!(!store.collection_exists(COLLECTION).await.unwrap());
    }

    #[tokio::test]
    async fn upsert_and_search() {
        let store = store_with_collection().await;
        let batch = vec![
            point("a", [1.0, 0.0, 0.0], [("name", serde_json::json!("alpha"))]),
            point("b", [0.0, 1.0, 0.0], [("name", serde_json::json!("beta"))]),
        ];
        store.upsert(COLLECTION, batch).await.unwrap();

        let hits = store
            .search(COLLECTION, vec![1.0, 0.0, 0.0], 2, None)
            .await
            .unwrap();

        assert_eq!(hits.len(), 2);
        // The point identical to the query must rank first with a score of 1.
        let best = &hits[0];
        assert_eq!(best.id, "a");
        assert!((best.score - 1.0).abs() < f32::EPSILON);
    }

    #[tokio::test]
    async fn search_with_filter() {
        let store = store_with_collection().await;
        let batch = vec![
            point("a", [1.0, 0.0, 0.0], [("role", serde_json::json!("user"))]),
            point("b", [0.9, 0.1, 0.0], [("role", serde_json::json!("assistant"))]),
        ];
        store.upsert(COLLECTION, batch).await.unwrap();

        let only_users = VectorFilter {
            must: vec![crate::vector_store::FieldCondition {
                field: "role".into(),
                value: FieldValue::Text("user".into()),
            }],
            must_not: vec![],
        };
        let hits = store
            .search(COLLECTION, vec![1.0, 0.0, 0.0], 10, Some(only_users))
            .await
            .unwrap();

        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].id, "a");
    }

    #[tokio::test]
    async fn delete_by_ids_removes_points() {
        let store = store_with_collection().await;
        store
            .upsert(COLLECTION, vec![point("a", [1.0, 0.0, 0.0], [])])
            .await
            .unwrap();
        store
            .delete_by_ids(COLLECTION, vec!["a".into()])
            .await
            .unwrap();

        let hits = store
            .search(COLLECTION, vec![1.0, 0.0, 0.0], 10, None)
            .await
            .unwrap();
        assert!(hits.is_empty());
    }

    #[tokio::test]
    async fn scroll_all_extracts_strings() {
        let store = store_with_collection().await;
        let mixed_payload = point(
            "a",
            [1.0, 0.0, 0.0],
            [
                ("name", serde_json::json!("alpha")),
                ("desc", serde_json::json!("first")),
                ("num", serde_json::json!(42)),
            ],
        );
        store.upsert(COLLECTION, vec![mixed_payload]).await.unwrap();

        let scrolled = store.scroll_all(COLLECTION, "name").await.unwrap();
        assert_eq!(scrolled.len(), 1);

        let fields = scrolled.get("alpha").unwrap();
        assert_eq!(fields.get("desc").unwrap(), "first");
        // Non-string payload values are not carried over.
        assert!(!fields.contains_key("num"));
    }

    #[test]
    fn cosine_similarity_import_wired() {
        // Smoke test only: checks that the imported function is callable from this
        // module. Edge-case coverage is in math.rs.
        let score = cosine_similarity(&[1.0, 0.0, 0.0], &[0.0, 1.0, 0.0]);
        assert!(!score.is_nan());
    }

    #[tokio::test]
    async fn default_impl() {
        let store = InMemoryVectorStore::default();
        assert!(!store.collection_exists("any").await.unwrap());
    }

    #[test]
    fn debug_format() {
        let store = InMemoryVectorStore::new();
        let rendered = format!("{store:?}");
        assert!(rendered.contains("InMemoryVectorStore"));
    }
}