Skip to main content

zeph_memory/
in_memory_store.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::collections::HashMap;
5use std::future::Future;
6use std::pin::Pin;
7use std::sync::RwLock;
8
9use crate::vector_store::{
10    FieldValue, ScoredVectorPoint, VectorFilter, VectorPoint, VectorStore, VectorStoreError,
11};
12
13type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
14
15struct StoredPoint {
16    vector: Vec<f32>,
17    payload: HashMap<String, serde_json::Value>,
18}
19
20struct InMemoryCollection {
21    points: HashMap<String, StoredPoint>,
22}
23
24pub struct InMemoryVectorStore {
25    collections: RwLock<HashMap<String, InMemoryCollection>>,
26}
27
28impl InMemoryVectorStore {
29    #[must_use]
30    pub fn new() -> Self {
31        Self {
32            collections: RwLock::new(HashMap::new()),
33        }
34    }
35}
36
37impl Default for InMemoryVectorStore {
38    fn default() -> Self {
39        Self::new()
40    }
41}
42
43impl std::fmt::Debug for InMemoryVectorStore {
44    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45        f.debug_struct("InMemoryVectorStore")
46            .finish_non_exhaustive()
47    }
48}
49
/// Cosine of the angle between `a` and `b`, i.e. `a·b / (|a| |b|)`.
///
/// Returns `0.0` when either vector has zero magnitude rather than dividing by zero.
/// The dot product only spans the common prefix of the two slices (it is built with
/// `zip`), while each magnitude spans its whole slice, so callers should pass vectors
/// of equal length.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let magnitude = |v: &[f32]| -> f32 { v.iter().map(|c| c * c).sum::<f32>().sqrt() };
    let (len_a, len_b) = (magnitude(a), magnitude(b));
    if len_a == 0.0 || len_b == 0.0 {
        0.0
    } else {
        let dot = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
        dot / (len_a * len_b)
    }
}
59
60fn matches_filter(payload: &HashMap<String, serde_json::Value>, filter: &VectorFilter) -> bool {
61    for cond in &filter.must {
62        let Some(val) = payload.get(&cond.field) else {
63            return false;
64        };
65        if !field_matches(val, &cond.value) {
66            return false;
67        }
68    }
69    for cond in &filter.must_not {
70        if let Some(val) = payload.get(&cond.field)
71            && field_matches(val, &cond.value)
72        {
73            return false;
74        }
75    }
76    true
77}
78
79fn field_matches(val: &serde_json::Value, expected: &FieldValue) -> bool {
80    match expected {
81        FieldValue::Integer(i) => val.as_i64() == Some(*i),
82        FieldValue::Text(s) => val.as_str() == Some(s.as_str()),
83    }
84}
85
86impl VectorStore for InMemoryVectorStore {
87    fn ensure_collection(
88        &self,
89        collection: &str,
90        _vector_size: u64,
91    ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
92        let collection = collection.to_owned();
93        Box::pin(async move {
94            let mut cols = self
95                .collections
96                .write()
97                .map_err(|e| VectorStoreError::Collection(e.to_string()))?;
98            cols.entry(collection)
99                .or_insert_with(|| InMemoryCollection {
100                    points: HashMap::new(),
101                });
102            Ok(())
103        })
104    }
105
106    fn collection_exists(&self, collection: &str) -> BoxFuture<'_, Result<bool, VectorStoreError>> {
107        let collection = collection.to_owned();
108        Box::pin(async move {
109            let cols = self
110                .collections
111                .read()
112                .map_err(|e| VectorStoreError::Collection(e.to_string()))?;
113            Ok(cols.contains_key(&collection))
114        })
115    }
116
117    fn delete_collection(&self, collection: &str) -> BoxFuture<'_, Result<(), VectorStoreError>> {
118        let collection = collection.to_owned();
119        Box::pin(async move {
120            let mut cols = self
121                .collections
122                .write()
123                .map_err(|e| VectorStoreError::Collection(e.to_string()))?;
124            cols.remove(&collection);
125            Ok(())
126        })
127    }
128
129    fn upsert(
130        &self,
131        collection: &str,
132        points: Vec<VectorPoint>,
133    ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
134        let collection = collection.to_owned();
135        Box::pin(async move {
136            let mut cols = self
137                .collections
138                .write()
139                .map_err(|e| VectorStoreError::Upsert(e.to_string()))?;
140            let col = cols.get_mut(&collection).ok_or_else(|| {
141                VectorStoreError::Upsert(format!("collection {collection} not found"))
142            })?;
143            for p in points {
144                col.points.insert(
145                    p.id,
146                    StoredPoint {
147                        vector: p.vector,
148                        payload: p.payload,
149                    },
150                );
151            }
152            Ok(())
153        })
154    }
155
156    fn search(
157        &self,
158        collection: &str,
159        vector: Vec<f32>,
160        limit: u64,
161        filter: Option<VectorFilter>,
162    ) -> BoxFuture<'_, Result<Vec<ScoredVectorPoint>, VectorStoreError>> {
163        let collection = collection.to_owned();
164        Box::pin(async move {
165            let cols = self
166                .collections
167                .read()
168                .map_err(|e| VectorStoreError::Search(e.to_string()))?;
169            let col = cols.get(&collection).ok_or_else(|| {
170                VectorStoreError::Search(format!("collection {collection} not found"))
171            })?;
172
173            let empty_filter = VectorFilter::default();
174            let f = filter.as_ref().unwrap_or(&empty_filter);
175
176            let mut scored: Vec<ScoredVectorPoint> = col
177                .points
178                .iter()
179                .filter(|(_, sp)| matches_filter(&sp.payload, f))
180                .map(|(id, sp)| ScoredVectorPoint {
181                    id: id.clone(),
182                    score: cosine_similarity(&vector, &sp.vector),
183                    payload: sp.payload.clone(),
184                })
185                .collect();
186
187            scored.sort_by(|a, b| {
188                b.score
189                    .partial_cmp(&a.score)
190                    .unwrap_or(std::cmp::Ordering::Equal)
191            });
192            #[expect(clippy::cast_possible_truncation)]
193            scored.truncate(limit as usize);
194            Ok(scored)
195        })
196    }
197
198    fn delete_by_ids(
199        &self,
200        collection: &str,
201        ids: Vec<String>,
202    ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
203        let collection = collection.to_owned();
204        Box::pin(async move {
205            if ids.is_empty() {
206                return Ok(());
207            }
208            let mut cols = self
209                .collections
210                .write()
211                .map_err(|e| VectorStoreError::Delete(e.to_string()))?;
212            let col = cols.get_mut(&collection).ok_or_else(|| {
213                VectorStoreError::Delete(format!("collection {collection} not found"))
214            })?;
215            for id in &ids {
216                col.points.remove(id);
217            }
218            Ok(())
219        })
220    }
221
222    fn scroll_all(
223        &self,
224        collection: &str,
225        key_field: &str,
226    ) -> BoxFuture<'_, Result<HashMap<String, HashMap<String, String>>, VectorStoreError>> {
227        let collection = collection.to_owned();
228        let key_field = key_field.to_owned();
229        Box::pin(async move {
230            let cols = self
231                .collections
232                .read()
233                .map_err(|e| VectorStoreError::Scroll(e.to_string()))?;
234            let col = cols.get(&collection).ok_or_else(|| {
235                VectorStoreError::Scroll(format!("collection {collection} not found"))
236            })?;
237
238            let mut result = HashMap::new();
239            for sp in col.points.values() {
240                let Some(key_val) = sp.payload.get(&key_field).and_then(|v| v.as_str()) else {
241                    continue;
242                };
243                let mut fields = HashMap::new();
244                for (k, v) in &sp.payload {
245                    if let Some(s) = v.as_str() {
246                        fields.insert(k.clone(), s.to_owned());
247                    }
248                }
249                result.insert(key_val.to_owned(), fields);
250            }
251            Ok(result)
252        })
253    }
254
255    fn health_check(&self) -> BoxFuture<'_, Result<bool, VectorStoreError>> {
256        Box::pin(async { Ok(true) })
257    }
258}
259
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a [`VectorPoint`] from borrowed test data.
    fn point(id: &str, vector: Vec<f32>, payload: &[(&str, serde_json::Value)]) -> VectorPoint {
        VectorPoint {
            id: id.into(),
            vector,
            payload: payload
                .iter()
                .map(|(key, value)| ((*key).to_owned(), value.clone()))
                .collect(),
        }
    }

    /// Equality condition on a text field, for building filters.
    fn text_cond(field: &str, value: &str) -> crate::vector_store::FieldCondition {
        crate::vector_store::FieldCondition {
            field: field.into(),
            value: FieldValue::Text(value.into()),
        }
    }

    #[tokio::test]
    async fn ensure_collection_and_exists() {
        let store = InMemoryVectorStore::new();
        assert!(!store.collection_exists("test").await.unwrap());
        store.ensure_collection("test", 3).await.unwrap();
        assert!(store.collection_exists("test").await.unwrap());
    }

    #[tokio::test]
    async fn ensure_collection_idempotent() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();
        store.ensure_collection("test", 3).await.unwrap();
        assert!(store.collection_exists("test").await.unwrap());
    }

    #[tokio::test]
    async fn delete_collection_removes() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();
        store.delete_collection("test").await.unwrap();
        assert!(!store.collection_exists("test").await.unwrap());
    }

    #[tokio::test]
    async fn upsert_and_search() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();

        let points = vec![
            VectorPoint {
                id: "a".into(),
                vector: vec![1.0, 0.0, 0.0],
                payload: HashMap::from([("name".into(), serde_json::json!("alpha"))]),
            },
            VectorPoint {
                id: "b".into(),
                vector: vec![0.0, 1.0, 0.0],
                payload: HashMap::from([("name".into(), serde_json::json!("beta"))]),
            },
        ];
        store.upsert("test", points).await.unwrap();

        let results = store
            .search("test", vec![1.0, 0.0, 0.0], 2, None)
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "a");
        assert!((results[0].score - 1.0).abs() < f32::EPSILON);
    }

    #[tokio::test]
    async fn search_with_filter() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();

        let points = vec![
            VectorPoint {
                id: "a".into(),
                vector: vec![1.0, 0.0, 0.0],
                payload: HashMap::from([("role".into(), serde_json::json!("user"))]),
            },
            VectorPoint {
                id: "b".into(),
                vector: vec![0.9, 0.1, 0.0],
                payload: HashMap::from([("role".into(), serde_json::json!("assistant"))]),
            },
        ];
        store.upsert("test", points).await.unwrap();

        let filter = VectorFilter {
            must: vec![crate::vector_store::FieldCondition {
                field: "role".into(),
                value: FieldValue::Text("user".into()),
            }],
            must_not: vec![],
        };
        let results = store
            .search("test", vec![1.0, 0.0, 0.0], 10, Some(filter))
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "a");
    }

    #[tokio::test]
    async fn delete_by_ids_removes_points() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();

        let points = vec![VectorPoint {
            id: "a".into(),
            vector: vec![1.0, 0.0, 0.0],
            payload: HashMap::new(),
        }];
        store.upsert("test", points).await.unwrap();
        store.delete_by_ids("test", vec!["a".into()]).await.unwrap();

        let results = store
            .search("test", vec![1.0, 0.0, 0.0], 10, None)
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn scroll_all_extracts_strings() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();

        let points = vec![VectorPoint {
            id: "a".into(),
            vector: vec![1.0, 0.0, 0.0],
            payload: HashMap::from([
                ("name".into(), serde_json::json!("alpha")),
                ("desc".into(), serde_json::json!("first")),
                ("num".into(), serde_json::json!(42)),
            ]),
        }];
        store.upsert("test", points).await.unwrap();

        let result = store.scroll_all("test", "name").await.unwrap();
        assert_eq!(result.len(), 1);
        let fields = result.get("alpha").unwrap();
        assert_eq!(fields.get("desc").unwrap(), "first");
        assert!(!fields.contains_key("num"));
    }

    #[test]
    fn cosine_similarity_orthogonal() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];
        assert!((cosine_similarity(&a, &b)).abs() < f32::EPSILON);
    }

    #[tokio::test]
    async fn default_impl() {
        let store = InMemoryVectorStore::default();
        assert!(!store.collection_exists("any").await.unwrap());
    }

    #[test]
    fn debug_format() {
        let store = InMemoryVectorStore::new();
        let dbg = format!("{store:?}");
        assert!(dbg.contains("InMemoryVectorStore"));
    }

    #[tokio::test]
    async fn search_orders_by_score_and_respects_limit() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();
        let points = vec![
            point("far", vec![0.0, 0.0, 1.0], &[]),
            point("near", vec![0.8, 0.6, 0.0], &[]),
            point("exact", vec![1.0, 0.0, 0.0], &[]),
        ];
        store.upsert("test", points).await.unwrap();

        let results = store
            .search("test", vec![1.0, 0.0, 0.0], 2, None)
            .await
            .unwrap();
        let ids: Vec<&str> = results.iter().map(|r| r.id.as_str()).collect();
        assert_eq!(ids, ["exact", "near"]);
    }

    #[tokio::test]
    async fn search_with_zero_limit_is_empty() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();
        store
            .upsert("test", vec![point("a", vec![1.0, 0.0, 0.0], &[])])
            .await
            .unwrap();

        let results = store
            .search("test", vec![1.0, 0.0, 0.0], 0, None)
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn search_with_must_not_filter_keeps_points_missing_the_field() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();
        let points = vec![
            point("a", vec![1.0, 0.0, 0.0], &[("role", serde_json::json!("user"))]),
            point("b", vec![0.9, 0.1, 0.0], &[("role", serde_json::json!("assistant"))]),
            point("c", vec![0.0, 1.0, 0.0], &[]),
        ];
        store.upsert("test", points).await.unwrap();

        let filter = VectorFilter {
            must: vec![],
            must_not: vec![text_cond("role", "assistant")],
        };
        let results = store
            .search("test", vec![1.0, 0.0, 0.0], 10, Some(filter))
            .await
            .unwrap();
        let ids: Vec<&str> = results.iter().map(|r| r.id.as_str()).collect();
        assert_eq!(ids, ["a", "c"]);
    }

    #[tokio::test]
    async fn search_with_integer_filter() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();
        let points = vec![
            point("a", vec![1.0, 0.0, 0.0], &[("n", serde_json::json!(42))]),
            point("b", vec![1.0, 0.0, 0.0], &[("n", serde_json::json!(7))]),
        ];
        store.upsert("test", points).await.unwrap();

        let filter = VectorFilter {
            must: vec![crate::vector_store::FieldCondition {
                field: "n".into(),
                value: FieldValue::Integer(42),
            }],
            must_not: vec![],
        };
        let results = store
            .search("test", vec![1.0, 0.0, 0.0], 10, Some(filter))
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "a");
    }

    #[tokio::test]
    async fn upsert_replaces_point_with_same_id() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();
        store
            .upsert(
                "test",
                vec![point("a", vec![1.0, 0.0, 0.0], &[("name", serde_json::json!("alpha"))])],
            )
            .await
            .unwrap();
        store
            .upsert(
                "test",
                vec![point("a", vec![0.0, 1.0, 0.0], &[("name", serde_json::json!("gamma"))])],
            )
            .await
            .unwrap();

        let results = store
            .search("test", vec![0.0, 1.0, 0.0], 10, None)
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "a");
        assert!((results[0].score - 1.0).abs() < 1e-6);
        assert_eq!(
            results[0].payload.get("name"),
            Some(&serde_json::json!("gamma"))
        );
    }

    #[tokio::test]
    async fn operations_on_missing_collection_fail() {
        let store = InMemoryVectorStore::new();
        assert!(
            store
                .upsert("missing", vec![point("a", vec![1.0], &[])])
                .await
                .is_err()
        );
        assert!(store.search("missing", vec![1.0], 1, None).await.is_err());
        assert!(
            store
                .delete_by_ids("missing", vec!["a".into()])
                .await
                .is_err()
        );
        assert!(store.scroll_all("missing", "name").await.is_err());
    }

    #[tokio::test]
    async fn scroll_all_skips_points_without_key_field() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();
        let points = vec![
            point("a", vec![1.0, 0.0, 0.0], &[("name", serde_json::json!("alpha"))]),
            point("b", vec![0.0, 1.0, 0.0], &[("desc", serde_json::json!("no key"))]),
        ];
        store.upsert("test", points).await.unwrap();

        let result = store.scroll_all("test", "name").await.unwrap();
        assert_eq!(result.len(), 1);
        assert!(result.contains_key("alpha"));
    }

    #[test]
    fn cosine_similarity_handles_zero_and_opposite_vectors() {
        assert!(cosine_similarity(&[0.0, 0.0], &[1.0, 2.0]).abs() < f32::EPSILON);
        assert!((cosine_similarity(&[1.0, 1.0], &[-1.0, -1.0]) + 1.0).abs() < 1e-6);
    }

    #[tokio::test]
    async fn health_check_reports_healthy() {
        let store = InMemoryVectorStore::new();
        assert!(store.health_check().await.unwrap());
    }
}