// Source: zeph_memory/in_memory_store.rs

// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::collections::HashMap;
5
6use parking_lot::RwLock;
7
8use crate::vector_store::{
9    BoxFuture, FieldValue, ScoredVectorPoint, VectorFilter, VectorPoint, VectorStore,
10    VectorStoreError,
11};
12
13struct StoredPoint {
14    vector: Vec<f32>,
15    payload: HashMap<String, serde_json::Value>,
16}
17
18struct InMemoryCollection {
19    points: HashMap<String, StoredPoint>,
20}
21
22pub struct InMemoryVectorStore {
23    collections: RwLock<HashMap<String, InMemoryCollection>>,
24}
25
26impl InMemoryVectorStore {
27    #[must_use]
28    pub fn new() -> Self {
29        Self {
30            collections: RwLock::new(HashMap::new()),
31        }
32    }
33}
34
35impl Default for InMemoryVectorStore {
36    fn default() -> Self {
37        Self::new()
38    }
39}
40
41impl std::fmt::Debug for InMemoryVectorStore {
42    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43        f.debug_struct("InMemoryVectorStore")
44            .finish_non_exhaustive()
45    }
46}
47
48use zeph_common::math::cosine_similarity;
49
50fn matches_filter(payload: &HashMap<String, serde_json::Value>, filter: &VectorFilter) -> bool {
51    for cond in &filter.must {
52        let Some(val) = payload.get(&cond.field) else {
53            return false;
54        };
55        if !field_matches(val, &cond.value) {
56            return false;
57        }
58    }
59    for cond in &filter.must_not {
60        if let Some(val) = payload.get(&cond.field)
61            && field_matches(val, &cond.value)
62        {
63            return false;
64        }
65    }
66    true
67}
68
69fn field_matches(val: &serde_json::Value, expected: &FieldValue) -> bool {
70    match expected {
71        FieldValue::Integer(i) => val.as_i64() == Some(*i),
72        FieldValue::Text(s) => val.as_str() == Some(s.as_str()),
73    }
74}
75
76impl VectorStore for InMemoryVectorStore {
77    fn ensure_collection(
78        &self,
79        collection: &str,
80        _vector_size: u64,
81    ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
82        let collection = collection.to_owned();
83        Box::pin(async move {
84            let mut cols = self.collections.write();
85            cols.entry(collection)
86                .or_insert_with(|| InMemoryCollection {
87                    points: HashMap::new(),
88                });
89            Ok(())
90        })
91    }
92
93    fn collection_exists(&self, collection: &str) -> BoxFuture<'_, Result<bool, VectorStoreError>> {
94        let collection = collection.to_owned();
95        Box::pin(async move {
96            let cols = self.collections.read();
97            Ok(cols.contains_key(&collection))
98        })
99    }
100
101    fn delete_collection(&self, collection: &str) -> BoxFuture<'_, Result<(), VectorStoreError>> {
102        let collection = collection.to_owned();
103        Box::pin(async move {
104            let mut cols = self.collections.write();
105            cols.remove(&collection);
106            Ok(())
107        })
108    }
109
110    fn upsert(
111        &self,
112        collection: &str,
113        points: Vec<VectorPoint>,
114    ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
115        let collection = collection.to_owned();
116        Box::pin(async move {
117            let mut cols = self.collections.write();
118            let col = cols.get_mut(&collection).ok_or_else(|| {
119                VectorStoreError::Upsert(format!("collection {collection} not found"))
120            })?;
121            for p in points {
122                col.points.insert(
123                    p.id,
124                    StoredPoint {
125                        vector: p.vector,
126                        payload: p.payload,
127                    },
128                );
129            }
130            Ok(())
131        })
132    }
133
134    fn search(
135        &self,
136        collection: &str,
137        vector: Vec<f32>,
138        limit: u64,
139        filter: Option<VectorFilter>,
140    ) -> BoxFuture<'_, Result<Vec<ScoredVectorPoint>, VectorStoreError>> {
141        let collection = collection.to_owned();
142        Box::pin(async move {
143            let cols = self.collections.read();
144            let col = cols.get(&collection).ok_or_else(|| {
145                VectorStoreError::Search(format!("collection {collection} not found"))
146            })?;
147
148            let empty_filter = VectorFilter::default();
149            let f = filter.as_ref().unwrap_or(&empty_filter);
150
151            let mut scored: Vec<ScoredVectorPoint> = col
152                .points
153                .iter()
154                .filter(|(_, sp)| matches_filter(&sp.payload, f))
155                .map(|(id, sp)| ScoredVectorPoint {
156                    id: id.clone(),
157                    score: cosine_similarity(&vector, &sp.vector),
158                    payload: sp.payload.clone(),
159                })
160                .collect();
161
162            scored.sort_by(|a, b| {
163                b.score
164                    .partial_cmp(&a.score)
165                    .unwrap_or(std::cmp::Ordering::Equal)
166            });
167            #[expect(clippy::cast_possible_truncation)]
168            scored.truncate(limit as usize);
169            Ok(scored)
170        })
171    }
172
173    fn delete_by_ids(
174        &self,
175        collection: &str,
176        ids: Vec<String>,
177    ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
178        let collection = collection.to_owned();
179        Box::pin(async move {
180            if ids.is_empty() {
181                return Ok(());
182            }
183            let mut cols = self.collections.write();
184            let col = cols.get_mut(&collection).ok_or_else(|| {
185                VectorStoreError::Delete(format!("collection {collection} not found"))
186            })?;
187            for id in &ids {
188                col.points.remove(id);
189            }
190            Ok(())
191        })
192    }
193
194    fn scroll_all(
195        &self,
196        collection: &str,
197        key_field: &str,
198    ) -> BoxFuture<'_, Result<HashMap<String, HashMap<String, String>>, VectorStoreError>> {
199        let collection = collection.to_owned();
200        let key_field = key_field.to_owned();
201        Box::pin(async move {
202            let cols = self.collections.read();
203            let col = cols.get(&collection).ok_or_else(|| {
204                VectorStoreError::Scroll(format!("collection {collection} not found"))
205            })?;
206
207            let mut result = HashMap::new();
208            for sp in col.points.values() {
209                let Some(key_val) = sp.payload.get(&key_field).and_then(|v| v.as_str()) else {
210                    continue;
211                };
212                let mut fields = HashMap::new();
213                for (k, v) in &sp.payload {
214                    if let Some(s) = v.as_str() {
215                        fields.insert(k.clone(), s.to_owned());
216                    }
217                }
218                result.insert(key_val.to_owned(), fields);
219            }
220            Ok(result)
221        })
222    }
223
224    fn health_check(&self) -> BoxFuture<'_, Result<bool, VectorStoreError>> {
225        Box::pin(async { Ok(true) })
226    }
227}
228
#[cfg(test)]
mod tests {
    use super::*;
    use crate::vector_store::FieldCondition;

    /// Builds a three-dimensional point with the given id and payload entries.
    fn point(id: &str, vector: [f32; 3], payload: &[(&str, serde_json::Value)]) -> VectorPoint {
        VectorPoint {
            id: id.into(),
            vector: vector.to_vec(),
            payload: payload
                .iter()
                .map(|(k, v)| ((*k).to_owned(), v.clone()))
                .collect(),
        }
    }

    /// Returns a store that already holds an empty collection named `test`.
    async fn store_with_collection() -> InMemoryVectorStore {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();
        store
    }

    /// Builds a single-condition filter on a text field.
    fn text_condition(field: &str, value: &str) -> FieldCondition {
        FieldCondition {
            field: field.into(),
            value: FieldValue::Text(value.into()),
        }
    }

    #[tokio::test]
    async fn ensure_collection_and_exists() {
        let store = InMemoryVectorStore::new();
        assert!(!store.collection_exists("test").await.unwrap());
        store.ensure_collection("test", 3).await.unwrap();
        assert!(store.collection_exists("test").await.unwrap());
    }

    #[tokio::test]
    async fn ensure_collection_idempotent() {
        let store = store_with_collection().await;
        store
            .upsert("test", vec![point("a", [1.0, 0.0, 0.0], &[])])
            .await
            .unwrap();

        // A second call must neither fail nor wipe what is already stored.
        store.ensure_collection("test", 3).await.unwrap();
        assert!(store.collection_exists("test").await.unwrap());
        let results = store
            .search("test", vec![1.0, 0.0, 0.0], 10, None)
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
    }

    #[tokio::test]
    async fn delete_collection_removes() {
        let store = store_with_collection().await;
        store.delete_collection("test").await.unwrap();
        assert!(!store.collection_exists("test").await.unwrap());
    }

    #[tokio::test]
    async fn upsert_and_search() {
        let store = store_with_collection().await;
        let points = vec![
            point("a", [1.0, 0.0, 0.0], &[("name", serde_json::json!("alpha"))]),
            point("b", [0.0, 1.0, 0.0], &[("name", serde_json::json!("beta"))]),
        ];
        store.upsert("test", points).await.unwrap();

        let results = store
            .search("test", vec![1.0, 0.0, 0.0], 2, None)
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "a");
        assert!((results[0].score - 1.0).abs() < f32::EPSILON);
    }

    #[tokio::test]
    async fn upsert_replaces_existing_id() {
        let store = store_with_collection().await;
        store
            .upsert(
                "test",
                vec![point("a", [1.0, 0.0, 0.0], &[("name", serde_json::json!("alpha"))])],
            )
            .await
            .unwrap();
        store
            .upsert(
                "test",
                vec![point("a", [1.0, 0.0, 0.0], &[("name", serde_json::json!("beta"))])],
            )
            .await
            .unwrap();

        let results = store
            .search("test", vec![1.0, 0.0, 0.0], 10, None)
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        let name = results[0].payload.get("name").and_then(|v| v.as_str());
        assert_eq!(name, Some("beta"));
    }

    #[tokio::test]
    async fn search_respects_limit() {
        let store = store_with_collection().await;
        let points = vec![
            point("a", [1.0, 0.0, 0.0], &[]),
            point("b", [0.0, 1.0, 0.0], &[]),
        ];
        store.upsert("test", points).await.unwrap();

        let top = store
            .search("test", vec![1.0, 0.0, 0.0], 1, None)
            .await
            .unwrap();
        assert_eq!(top.len(), 1);
        assert_eq!(top[0].id, "a");

        let none = store
            .search("test", vec![1.0, 0.0, 0.0], 0, None)
            .await
            .unwrap();
        assert!(none.is_empty());
    }

    #[tokio::test]
    async fn search_with_filter() {
        let store = store_with_collection().await;
        let points = vec![
            point("a", [1.0, 0.0, 0.0], &[("role", serde_json::json!("user"))]),
            point("b", [0.9, 0.1, 0.0], &[("role", serde_json::json!("assistant"))]),
        ];
        store.upsert("test", points).await.unwrap();

        let filter = VectorFilter {
            must: vec![text_condition("role", "user")],
            must_not: vec![],
        };
        let results = store
            .search("test", vec![1.0, 0.0, 0.0], 10, Some(filter))
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "a");
    }

    #[tokio::test]
    async fn search_with_must_not_filter() {
        let store = store_with_collection().await;
        let points = vec![
            point("a", [1.0, 0.0, 0.0], &[("role", serde_json::json!("user"))]),
            point("b", [0.9, 0.1, 0.0], &[("role", serde_json::json!("assistant"))]),
            // No `role` field at all: a `must_not` condition cannot exclude it.
            point("c", [0.0, 1.0, 0.0], &[]),
        ];
        store.upsert("test", points).await.unwrap();

        let filter = VectorFilter {
            must: vec![],
            must_not: vec![text_condition("role", "user")],
        };
        let results = store
            .search("test", vec![1.0, 0.0, 0.0], 10, Some(filter))
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|r| r.id != "a"));
    }

    #[tokio::test]
    async fn search_with_integer_filter() {
        let store = store_with_collection().await;
        let points = vec![
            point("a", [1.0, 0.0, 0.0], &[("n", serde_json::json!(42))]),
            point("b", [1.0, 0.0, 0.0], &[("n", serde_json::json!(7))]),
            // A string never matches an integer condition, even if it looks numeric.
            point("c", [1.0, 0.0, 0.0], &[("n", serde_json::json!("42"))]),
        ];
        store.upsert("test", points).await.unwrap();

        let filter = VectorFilter {
            must: vec![FieldCondition {
                field: "n".into(),
                value: FieldValue::Integer(42),
            }],
            must_not: vec![],
        };
        let results = store
            .search("test", vec![1.0, 0.0, 0.0], 10, Some(filter))
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "a");
    }

    #[tokio::test]
    async fn delete_by_ids_removes_points() {
        let store = store_with_collection().await;
        store
            .upsert("test", vec![point("a", [1.0, 0.0, 0.0], &[])])
            .await
            .unwrap();
        store.delete_by_ids("test", vec!["a".into()]).await.unwrap();

        let results = store
            .search("test", vec![1.0, 0.0, 0.0], 10, None)
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn missing_collection_is_reported() {
        let store = InMemoryVectorStore::new();
        assert!(matches!(
            store.upsert("missing", Vec::new()).await,
            Err(VectorStoreError::Upsert(_))
        ));
        assert!(matches!(
            store.search("missing", vec![1.0, 0.0, 0.0], 1, None).await,
            Err(VectorStoreError::Search(_))
        ));
        assert!(matches!(
            store.delete_by_ids("missing", vec!["a".into()]).await,
            Err(VectorStoreError::Delete(_))
        ));
        assert!(matches!(
            store.scroll_all("missing", "name").await,
            Err(VectorStoreError::Scroll(_))
        ));
        // An empty id list returns before the collection is looked up.
        assert!(store.delete_by_ids("missing", Vec::new()).await.is_ok());
    }

    #[tokio::test]
    async fn scroll_all_extracts_strings() {
        let store = store_with_collection().await;
        let payload = [
            ("name", serde_json::json!("alpha")),
            ("desc", serde_json::json!("first")),
            ("num", serde_json::json!(42)),
        ];
        store
            .upsert("test", vec![point("a", [1.0, 0.0, 0.0], &payload)])
            .await
            .unwrap();

        let result = store.scroll_all("test", "name").await.unwrap();
        assert_eq!(result.len(), 1);
        let fields = result.get("alpha").unwrap();
        assert_eq!(fields.get("desc").unwrap(), "first");
        assert!(!fields.contains_key("num"));
    }

    #[test]
    fn cosine_similarity_import_wired() {
        // Smoke test: verifies the re-export binding is intact. Edge-case coverage is in math.rs.
        assert!(!cosine_similarity(&[1.0, 0.0, 0.0], &[0.0, 1.0, 0.0]).is_nan());
    }

    #[tokio::test]
    async fn default_impl() {
        let store = InMemoryVectorStore::default();
        assert!(!store.collection_exists("any").await.unwrap());
    }

    #[test]
    fn debug_format() {
        let store = InMemoryVectorStore::new();
        let dbg = format!("{store:?}");
        assert!(dbg.contains("InMemoryVectorStore"));
    }
}