1use std::collections::HashMap;
5
6use parking_lot::RwLock;
7
8use crate::vector_store::{
9 BoxFuture, FieldValue, ScoredVectorPoint, VectorFilter, VectorPoint, VectorStore,
10 VectorStoreError,
11};
12
13struct StoredPoint {
14 vector: Vec<f32>,
15 payload: HashMap<String, serde_json::Value>,
16}
17
18struct InMemoryCollection {
19 points: HashMap<String, StoredPoint>,
20}
21
22pub struct InMemoryVectorStore {
23 collections: RwLock<HashMap<String, InMemoryCollection>>,
24}
25
26impl InMemoryVectorStore {
27 #[must_use]
28 pub fn new() -> Self {
29 Self {
30 collections: RwLock::new(HashMap::new()),
31 }
32 }
33}
34
35impl Default for InMemoryVectorStore {
36 fn default() -> Self {
37 Self::new()
38 }
39}
40
41impl std::fmt::Debug for InMemoryVectorStore {
42 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43 f.debug_struct("InMemoryVectorStore")
44 .finish_non_exhaustive()
45 }
46}
47
48use zeph_common::math::cosine_similarity;
49
50fn matches_filter(payload: &HashMap<String, serde_json::Value>, filter: &VectorFilter) -> bool {
51 for cond in &filter.must {
52 let Some(val) = payload.get(&cond.field) else {
53 return false;
54 };
55 if !field_matches(val, &cond.value) {
56 return false;
57 }
58 }
59 for cond in &filter.must_not {
60 if let Some(val) = payload.get(&cond.field)
61 && field_matches(val, &cond.value)
62 {
63 return false;
64 }
65 }
66 true
67}
68
69fn field_matches(val: &serde_json::Value, expected: &FieldValue) -> bool {
70 match expected {
71 FieldValue::Integer(i) => val.as_i64() == Some(*i),
72 FieldValue::Text(s) => val.as_str() == Some(s.as_str()),
73 }
74}
75
76impl VectorStore for InMemoryVectorStore {
77 fn ensure_collection(
78 &self,
79 collection: &str,
80 _vector_size: u64,
81 ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
82 let collection = collection.to_owned();
83 Box::pin(async move {
84 let mut cols = self.collections.write();
85 cols.entry(collection)
86 .or_insert_with(|| InMemoryCollection {
87 points: HashMap::new(),
88 });
89 Ok(())
90 })
91 }
92
93 fn collection_exists(&self, collection: &str) -> BoxFuture<'_, Result<bool, VectorStoreError>> {
94 let collection = collection.to_owned();
95 Box::pin(async move {
96 let cols = self.collections.read();
97 Ok(cols.contains_key(&collection))
98 })
99 }
100
101 fn delete_collection(&self, collection: &str) -> BoxFuture<'_, Result<(), VectorStoreError>> {
102 let collection = collection.to_owned();
103 Box::pin(async move {
104 let mut cols = self.collections.write();
105 cols.remove(&collection);
106 Ok(())
107 })
108 }
109
110 fn upsert(
111 &self,
112 collection: &str,
113 points: Vec<VectorPoint>,
114 ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
115 let collection = collection.to_owned();
116 Box::pin(async move {
117 let mut cols = self.collections.write();
118 let col = cols.get_mut(&collection).ok_or_else(|| {
119 VectorStoreError::Upsert(format!("collection {collection} not found"))
120 })?;
121 for p in points {
122 col.points.insert(
123 p.id,
124 StoredPoint {
125 vector: p.vector,
126 payload: p.payload,
127 },
128 );
129 }
130 Ok(())
131 })
132 }
133
134 fn search(
135 &self,
136 collection: &str,
137 vector: Vec<f32>,
138 limit: u64,
139 filter: Option<VectorFilter>,
140 ) -> BoxFuture<'_, Result<Vec<ScoredVectorPoint>, VectorStoreError>> {
141 let collection = collection.to_owned();
142 Box::pin(async move {
143 let cols = self.collections.read();
144 let col = cols.get(&collection).ok_or_else(|| {
145 VectorStoreError::Search(format!("collection {collection} not found"))
146 })?;
147
148 let empty_filter = VectorFilter::default();
149 let f = filter.as_ref().unwrap_or(&empty_filter);
150
151 let mut scored: Vec<ScoredVectorPoint> = col
152 .points
153 .iter()
154 .filter(|(_, sp)| matches_filter(&sp.payload, f))
155 .map(|(id, sp)| ScoredVectorPoint {
156 id: id.clone(),
157 score: cosine_similarity(&vector, &sp.vector),
158 payload: sp.payload.clone(),
159 })
160 .collect();
161
162 scored.sort_by(|a, b| {
163 b.score
164 .partial_cmp(&a.score)
165 .unwrap_or(std::cmp::Ordering::Equal)
166 });
167 #[expect(clippy::cast_possible_truncation)]
168 scored.truncate(limit as usize);
169 Ok(scored)
170 })
171 }
172
173 fn delete_by_ids(
174 &self,
175 collection: &str,
176 ids: Vec<String>,
177 ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
178 let collection = collection.to_owned();
179 Box::pin(async move {
180 if ids.is_empty() {
181 return Ok(());
182 }
183 let mut cols = self.collections.write();
184 let col = cols.get_mut(&collection).ok_or_else(|| {
185 VectorStoreError::Delete(format!("collection {collection} not found"))
186 })?;
187 for id in &ids {
188 col.points.remove(id);
189 }
190 Ok(())
191 })
192 }
193
194 fn scroll_all(
195 &self,
196 collection: &str,
197 key_field: &str,
198 ) -> BoxFuture<'_, Result<HashMap<String, HashMap<String, String>>, VectorStoreError>> {
199 let collection = collection.to_owned();
200 let key_field = key_field.to_owned();
201 Box::pin(async move {
202 let cols = self.collections.read();
203 let col = cols.get(&collection).ok_or_else(|| {
204 VectorStoreError::Scroll(format!("collection {collection} not found"))
205 })?;
206
207 let mut result = HashMap::new();
208 for sp in col.points.values() {
209 let Some(key_val) = sp.payload.get(&key_field).and_then(|v| v.as_str()) else {
210 continue;
211 };
212 let mut fields = HashMap::new();
213 for (k, v) in &sp.payload {
214 if let Some(s) = v.as_str() {
215 fields.insert(k.clone(), s.to_owned());
216 }
217 }
218 result.insert(key_val.to_owned(), fields);
219 }
220 Ok(result)
221 })
222 }
223
224 fn health_check(&self) -> BoxFuture<'_, Result<bool, VectorStoreError>> {
225 Box::pin(async { Ok(true) })
226 }
227}
228
#[cfg(test)]
mod tests {
    use super::*;

    /// Collection name shared by the tests that need an existing collection.
    const COLLECTION: &str = "test";

    /// Builds a store in which `COLLECTION` has already been created.
    async fn store_with_collection() -> InMemoryVectorStore {
        let store = InMemoryVectorStore::new();
        store.ensure_collection(COLLECTION, 3).await.unwrap();
        store
    }

    /// Builds a three-dimensional point from borrowed parts.
    fn point(id: &str, vector: [f32; 3], payload: &[(&str, serde_json::Value)]) -> VectorPoint {
        VectorPoint {
            id: id.to_owned(),
            vector: vector.to_vec(),
            payload: payload
                .iter()
                .map(|(key, value)| ((*key).to_owned(), value.clone()))
                .collect(),
        }
    }

    #[tokio::test]
    async fn ensure_collection_and_exists() {
        let store = InMemoryVectorStore::new();
        let before = store.collection_exists(COLLECTION).await.unwrap();
        assert!(!before);

        store.ensure_collection(COLLECTION, 3).await.unwrap();
        let after = store.collection_exists(COLLECTION).await.unwrap();
        assert!(after);
    }

    #[tokio::test]
    async fn ensure_collection_idempotent() {
        let store = InMemoryVectorStore::new();
        for _ in 0..2 {
            store.ensure_collection(COLLECTION, 3).await.unwrap();
        }
        assert!(store.collection_exists(COLLECTION).await.unwrap());
    }

    #[tokio::test]
    async fn delete_collection_removes() {
        let store = store_with_collection().await;
        store.delete_collection(COLLECTION).await.unwrap();
        assert!(!store.collection_exists(COLLECTION).await.unwrap());
    }

    #[tokio::test]
    async fn upsert_and_search() {
        let store = store_with_collection().await;
        let points = vec![
            point("a", [1.0, 0.0, 0.0], &[("name", serde_json::json!("alpha"))]),
            point("b", [0.0, 1.0, 0.0], &[("name", serde_json::json!("beta"))]),
        ];
        store.upsert(COLLECTION, points).await.unwrap();

        let results = store
            .search(COLLECTION, vec![1.0, 0.0, 0.0], 2, None)
            .await
            .unwrap();

        assert_eq!(results.len(), 2);
        let best = &results[0];
        assert_eq!(best.id, "a");
        assert!((best.score - 1.0).abs() < f32::EPSILON);
    }

    #[tokio::test]
    async fn search_with_filter() {
        let store = store_with_collection().await;
        let points = vec![
            point("a", [1.0, 0.0, 0.0], &[("role", serde_json::json!("user"))]),
            point(
                "b",
                [0.9, 0.1, 0.0],
                &[("role", serde_json::json!("assistant"))],
            ),
        ];
        store.upsert(COLLECTION, points).await.unwrap();

        let only_users = VectorFilter {
            must: vec![crate::vector_store::FieldCondition {
                field: "role".into(),
                value: FieldValue::Text("user".into()),
            }],
            must_not: vec![],
        };
        let results = store
            .search(COLLECTION, vec![1.0, 0.0, 0.0], 10, Some(only_users))
            .await
            .unwrap();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "a");
    }

    #[tokio::test]
    async fn delete_by_ids_removes_points() {
        let store = store_with_collection().await;
        store
            .upsert(COLLECTION, vec![point("a", [1.0, 0.0, 0.0], &[])])
            .await
            .unwrap();
        store
            .delete_by_ids(COLLECTION, vec!["a".into()])
            .await
            .unwrap();

        let results = store
            .search(COLLECTION, vec![1.0, 0.0, 0.0], 10, None)
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn scroll_all_extracts_strings() {
        let store = store_with_collection().await;
        let payload = [
            ("name", serde_json::json!("alpha")),
            ("desc", serde_json::json!("first")),
            ("num", serde_json::json!(42)),
        ];
        store
            .upsert(COLLECTION, vec![point("a", [1.0, 0.0, 0.0], &payload)])
            .await
            .unwrap();

        let result = store.scroll_all(COLLECTION, "name").await.unwrap();
        assert_eq!(result.len(), 1);

        let fields = result.get("alpha").unwrap();
        assert_eq!(fields.get("desc").unwrap(), "first");
        assert!(!fields.contains_key("num"));
    }

    #[test]
    fn cosine_similarity_import_wired() {
        let similarity = cosine_similarity(&[1.0, 0.0, 0.0], &[0.0, 1.0, 0.0]);
        assert!(!similarity.is_nan());
    }

    #[tokio::test]
    async fn default_impl() {
        let store = InMemoryVectorStore::default();
        let exists = store.collection_exists("any").await.unwrap();
        assert!(!exists);
    }

    #[test]
    fn debug_format() {
        let store = InMemoryVectorStore::new();
        let rendered = format!("{store:?}");
        assert!(rendered.contains("InMemoryVectorStore"));
    }
}