1use std::collections::HashMap;
11
12use parking_lot::RwLock;
13
14use crate::vector_store::{
15 BoxFuture, FieldValue, ScoredVectorPoint, ScrollWithIdsResult, VectorFilter, VectorPoint,
16 VectorStore, VectorStoreError,
17};
18
19struct StoredPoint {
20 vector: Vec<f32>,
21 payload: HashMap<String, serde_json::Value>,
22}
23
24struct InMemoryCollection {
25 points: HashMap<String, StoredPoint>,
26}
27
28pub struct InMemoryVectorStore {
41 collections: RwLock<HashMap<String, InMemoryCollection>>,
42}
43
44impl InMemoryVectorStore {
45 #[must_use]
46 pub fn new() -> Self {
47 Self {
48 collections: RwLock::new(HashMap::new()),
49 }
50 }
51}
52
53impl Default for InMemoryVectorStore {
54 fn default() -> Self {
55 Self::new()
56 }
57}
58
59impl std::fmt::Debug for InMemoryVectorStore {
60 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61 f.debug_struct("InMemoryVectorStore")
62 .finish_non_exhaustive()
63 }
64}
65
66use zeph_common::math::cosine_similarity;
67
68fn matches_filter(payload: &HashMap<String, serde_json::Value>, filter: &VectorFilter) -> bool {
69 for cond in &filter.must {
70 let Some(val) = payload.get(&cond.field) else {
71 return false;
72 };
73 if !field_matches(val, &cond.value) {
74 return false;
75 }
76 }
77 for cond in &filter.must_not {
78 if let Some(val) = payload.get(&cond.field)
79 && field_matches(val, &cond.value)
80 {
81 return false;
82 }
83 }
84 true
85}
86
87fn field_matches(val: &serde_json::Value, expected: &FieldValue) -> bool {
88 match expected {
89 FieldValue::Integer(i) => val.as_i64() == Some(*i),
90 FieldValue::Text(s) => val.as_str() == Some(s.as_str()),
91 }
92}
93
94impl VectorStore for InMemoryVectorStore {
95 fn ensure_collection(
96 &self,
97 collection: &str,
98 _vector_size: u64,
99 ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
100 let collection = collection.to_owned();
101 Box::pin(async move {
102 let mut cols = self.collections.write();
103 cols.entry(collection)
104 .or_insert_with(|| InMemoryCollection {
105 points: HashMap::new(),
106 });
107 Ok(())
108 })
109 }
110
111 fn collection_exists(&self, collection: &str) -> BoxFuture<'_, Result<bool, VectorStoreError>> {
112 let collection = collection.to_owned();
113 Box::pin(async move {
114 let cols = self.collections.read();
115 Ok(cols.contains_key(&collection))
116 })
117 }
118
119 fn delete_collection(&self, collection: &str) -> BoxFuture<'_, Result<(), VectorStoreError>> {
120 let collection = collection.to_owned();
121 Box::pin(async move {
122 let mut cols = self.collections.write();
123 cols.remove(&collection);
124 Ok(())
125 })
126 }
127
128 fn upsert(
129 &self,
130 collection: &str,
131 points: Vec<VectorPoint>,
132 ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
133 let collection = collection.to_owned();
134 Box::pin(async move {
135 let mut cols = self.collections.write();
136 let col = cols.get_mut(&collection).ok_or_else(|| {
137 VectorStoreError::Upsert(format!("collection {collection} not found"))
138 })?;
139 for p in points {
140 col.points.insert(
141 p.id,
142 StoredPoint {
143 vector: p.vector,
144 payload: p.payload,
145 },
146 );
147 }
148 Ok(())
149 })
150 }
151
152 fn search(
153 &self,
154 collection: &str,
155 vector: Vec<f32>,
156 limit: u64,
157 filter: Option<VectorFilter>,
158 ) -> BoxFuture<'_, Result<Vec<ScoredVectorPoint>, VectorStoreError>> {
159 let collection = collection.to_owned();
160 Box::pin(async move {
161 let cols = self.collections.read();
162 let col = cols.get(&collection).ok_or_else(|| {
163 VectorStoreError::Search(format!("collection {collection} not found"))
164 })?;
165
166 let empty_filter = VectorFilter::default();
167 let f = filter.as_ref().unwrap_or(&empty_filter);
168
169 let mut scored: Vec<ScoredVectorPoint> = col
170 .points
171 .iter()
172 .filter(|(_, sp)| matches_filter(&sp.payload, f))
173 .map(|(id, sp)| ScoredVectorPoint {
174 id: id.clone(),
175 score: cosine_similarity(&vector, &sp.vector),
176 payload: sp.payload.clone(),
177 })
178 .collect();
179
180 scored.sort_by(|a, b| {
181 b.score
182 .partial_cmp(&a.score)
183 .unwrap_or(std::cmp::Ordering::Equal)
184 });
185 #[expect(clippy::cast_possible_truncation)]
186 scored.truncate(limit as usize);
187 Ok(scored)
188 })
189 }
190
191 fn delete_by_ids(
192 &self,
193 collection: &str,
194 ids: Vec<String>,
195 ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
196 let collection = collection.to_owned();
197 Box::pin(async move {
198 if ids.is_empty() {
199 return Ok(());
200 }
201 let mut cols = self.collections.write();
202 let col = cols.get_mut(&collection).ok_or_else(|| {
203 VectorStoreError::Delete(format!("collection {collection} not found"))
204 })?;
205 for id in &ids {
206 col.points.remove(id);
207 }
208 Ok(())
209 })
210 }
211
212 fn scroll_all(
213 &self,
214 collection: &str,
215 key_field: &str,
216 ) -> BoxFuture<'_, Result<HashMap<String, HashMap<String, String>>, VectorStoreError>> {
217 let collection = collection.to_owned();
218 let key_field = key_field.to_owned();
219 Box::pin(async move {
220 let cols = self.collections.read();
221 let col = cols.get(&collection).ok_or_else(|| {
222 VectorStoreError::Scroll(format!("collection {collection} not found"))
223 })?;
224
225 let mut result = HashMap::new();
226 for sp in col.points.values() {
227 let Some(key_val) = sp.payload.get(&key_field).and_then(|v| v.as_str()) else {
228 continue;
229 };
230 let mut fields = HashMap::new();
231 for (k, v) in &sp.payload {
232 if let Some(s) = v.as_str() {
233 fields.insert(k.clone(), s.to_owned());
234 }
235 }
236 result.insert(key_val.to_owned(), fields);
237 }
238 Ok(result)
239 })
240 }
241
242 fn scroll_all_with_point_ids(
243 &self,
244 collection: &str,
245 key_field: &str,
246 ) -> BoxFuture<'_, Result<ScrollWithIdsResult, VectorStoreError>> {
247 let collection = collection.to_owned();
248 let key_field = key_field.to_owned();
249 Box::pin(async move {
250 let cols = self.collections.read();
251 let col = cols.get(&collection).ok_or_else(|| {
252 VectorStoreError::Scroll(format!("collection {collection} not found"))
253 })?;
254
255 let mut result = Vec::new();
256 for (point_id, sp) in &col.points {
257 let Some(key_val) = sp.payload.get(&key_field).and_then(|v| v.as_str()) else {
258 continue;
259 };
260 let mut fields = HashMap::new();
261 for (k, v) in &sp.payload {
262 if let Some(s) = v.as_str() {
263 fields.insert(k.clone(), s.to_owned());
264 }
265 }
266 fields.insert(key_field.clone(), key_val.to_owned());
268 result.push((point_id.clone(), fields));
269 }
270 Ok(result)
271 })
272 }
273
274 fn health_check(&self) -> BoxFuture<'_, Result<bool, VectorStoreError>> {
275 Box::pin(async { Ok(true) })
276 }
277
278 fn get_points(
279 &self,
280 collection: &str,
281 ids: Vec<String>,
282 ) -> BoxFuture<'_, Result<Vec<VectorPoint>, VectorStoreError>> {
283 let collection = collection.to_owned();
284 Box::pin(async move {
285 let cols = self.collections.read();
286 let col = cols.get(&collection).ok_or_else(|| {
287 VectorStoreError::Unsupported(format!("collection {collection} not found"))
288 })?;
289 let points = ids
290 .into_iter()
291 .filter_map(|id| {
292 col.points.get(&id).map(|sp| VectorPoint {
293 id: id.clone(),
294 vector: sp.vector.clone(),
295 payload: sp.payload.clone(),
296 })
297 })
298 .collect();
299 Ok(points)
300 })
301 }
302}
303
#[cfg(test)]
mod tests {
    use super::*;
    use crate::vector_store::FieldCondition;

    /// Builds a 3-dimensional point with the given payload entries.
    fn point(id: &str, vector: [f32; 3], payload: &[(&str, serde_json::Value)]) -> VectorPoint {
        VectorPoint {
            id: id.to_owned(),
            vector: vector.to_vec(),
            payload: payload
                .iter()
                .map(|(k, v)| ((*k).to_owned(), v.clone()))
                .collect(),
        }
    }

    /// Returns a store whose collection `"test"` already holds `points`.
    async fn seeded_store(points: Vec<VectorPoint>) -> InMemoryVectorStore {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();
        store.upsert("test", points).await.unwrap();
        store
    }

    /// Ids of `results`, in returned order.
    fn ids(results: &[ScoredVectorPoint]) -> Vec<&str> {
        results.iter().map(|p| p.id.as_str()).collect()
    }

    #[tokio::test]
    async fn ensure_collection_and_exists() {
        let store = InMemoryVectorStore::new();
        assert!(!store.collection_exists("test").await.unwrap());
        store.ensure_collection("test", 3).await.unwrap();
        assert!(store.collection_exists("test").await.unwrap());
    }

    #[tokio::test]
    async fn ensure_collection_idempotent() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();
        store.ensure_collection("test", 3).await.unwrap();
        assert!(store.collection_exists("test").await.unwrap());
    }

    #[tokio::test]
    async fn delete_collection_removes() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();
        store.delete_collection("test").await.unwrap();
        assert!(!store.collection_exists("test").await.unwrap());
    }

    #[tokio::test]
    async fn upsert_and_search() {
        let store = seeded_store(vec![
            point("a", [1.0, 0.0, 0.0], &[("name", serde_json::json!("alpha"))]),
            point("b", [0.0, 1.0, 0.0], &[("name", serde_json::json!("beta"))]),
        ])
        .await;

        let results = store
            .search("test", vec![1.0, 0.0, 0.0], 2, None)
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "a");
        assert!((results[0].score - 1.0).abs() < f32::EPSILON);
    }

    #[tokio::test]
    async fn search_with_filter() {
        let store = seeded_store(vec![
            point("a", [1.0, 0.0, 0.0], &[("role", serde_json::json!("user"))]),
            point("b", [0.9, 0.1, 0.0], &[("role", serde_json::json!("assistant"))]),
        ])
        .await;

        let filter = VectorFilter {
            must: vec![FieldCondition {
                field: "role".into(),
                value: FieldValue::Text("user".into()),
            }],
            must_not: vec![],
        };
        let results = store
            .search("test", vec![1.0, 0.0, 0.0], 10, Some(filter))
            .await
            .unwrap();
        assert_eq!(ids(&results), ["a"]);
    }

    #[tokio::test]
    async fn search_must_not_excludes_only_matching_points() {
        let store = seeded_store(vec![
            point("a", [1.0, 0.0, 0.0], &[("role", serde_json::json!("user"))]),
            point("b", [0.9, 0.1, 0.0], &[("role", serde_json::json!("assistant"))]),
            point("c", [0.0, 1.0, 0.0], &[]),
        ])
        .await;

        let filter = VectorFilter {
            must: vec![],
            must_not: vec![FieldCondition {
                field: "role".into(),
                value: FieldValue::Text("assistant".into()),
            }],
        };
        let results = store
            .search("test", vec![1.0, 0.0, 0.0], 10, Some(filter))
            .await
            .unwrap();
        // "c" has no `role` field at all, so the `must_not` condition keeps it.
        assert_eq!(ids(&results), ["a", "c"]);
    }

    #[tokio::test]
    async fn search_with_integer_filter() {
        let store = seeded_store(vec![
            point("a", [1.0, 0.0, 0.0], &[("turn", serde_json::json!(1))]),
            point("b", [1.0, 0.0, 0.0], &[("turn", serde_json::json!(2))]),
        ])
        .await;

        let filter = VectorFilter {
            must: vec![FieldCondition {
                field: "turn".into(),
                value: FieldValue::Integer(2),
            }],
            must_not: vec![],
        };
        let results = store
            .search("test", vec![1.0, 0.0, 0.0], 10, Some(filter))
            .await
            .unwrap();
        assert_eq!(ids(&results), ["b"]);
    }

    #[tokio::test]
    async fn search_truncates_to_limit_in_score_order() {
        let store = seeded_store(vec![
            point("far", [0.0, 1.0, 0.0], &[]),
            point("exact", [1.0, 0.0, 0.0], &[]),
            point("near", [0.9, 0.1, 0.0], &[]),
        ])
        .await;

        let results = store
            .search("test", vec![1.0, 0.0, 0.0], 2, None)
            .await
            .unwrap();
        assert_eq!(ids(&results), ["exact", "near"]);
    }

    #[tokio::test]
    async fn operations_on_missing_collection_fail() {
        let store = InMemoryVectorStore::new();
        assert!(
            store
                .upsert("missing", vec![point("a", [1.0, 0.0, 0.0], &[])])
                .await
                .is_err()
        );
        assert!(
            store
                .search("missing", vec![1.0, 0.0, 0.0], 1, None)
                .await
                .is_err()
        );
        assert!(
            store
                .delete_by_ids("missing", vec!["a".into()])
                .await
                .is_err()
        );
        assert!(store.scroll_all("missing", "name").await.is_err());
        assert!(
            store
                .scroll_all_with_point_ids("missing", "name")
                .await
                .is_err()
        );
        assert!(
            store
                .get_points("missing", vec!["a".into()])
                .await
                .is_err()
        );
    }

    #[tokio::test]
    async fn upsert_replaces_existing_point() {
        let store = seeded_store(vec![point(
            "a",
            [1.0, 0.0, 0.0],
            &[("name", serde_json::json!("old"))],
        )])
        .await;
        store
            .upsert(
                "test",
                vec![point("a", [0.0, 1.0, 0.0], &[("name", serde_json::json!("new"))])],
            )
            .await
            .unwrap();

        let points = store.get_points("test", vec!["a".into()]).await.unwrap();
        assert_eq!(points.len(), 1);
        assert_eq!(points[0].vector, [0.0_f32, 1.0, 0.0]);
        assert_eq!(points[0].payload.get("name"), Some(&serde_json::json!("new")));
    }

    #[tokio::test]
    async fn get_points_skips_unknown_ids() {
        let store = seeded_store(vec![
            point("a", [1.0, 0.0, 0.0], &[]),
            point("b", [0.0, 1.0, 0.0], &[("name", serde_json::json!("beta"))]),
        ])
        .await;

        let points = store
            .get_points("test", vec!["missing".into(), "b".into()])
            .await
            .unwrap();
        assert_eq!(points.len(), 1);
        assert_eq!(points[0].id, "b");
        assert_eq!(points[0].vector, [0.0_f32, 1.0, 0.0]);
        assert_eq!(points[0].payload.get("name"), Some(&serde_json::json!("beta")));
    }

    #[tokio::test]
    async fn delete_by_ids_removes_points() {
        let store = seeded_store(vec![point("a", [1.0, 0.0, 0.0], &[])]).await;
        store.delete_by_ids("test", vec!["a".into()]).await.unwrap();

        let results = store
            .search("test", vec![1.0, 0.0, 0.0], 10, None)
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn scroll_all_extracts_strings() {
        let store = seeded_store(vec![point(
            "a",
            [1.0, 0.0, 0.0],
            &[
                ("name", serde_json::json!("alpha")),
                ("desc", serde_json::json!("first")),
                ("num", serde_json::json!(42)),
            ],
        )])
        .await;

        let result = store.scroll_all("test", "name").await.unwrap();
        assert_eq!(result.len(), 1);
        let fields = result.get("alpha").unwrap();
        assert_eq!(fields.get("desc").unwrap(), "first");
        assert_eq!(fields.get("name").map(String::as_str), Some("alpha"));
        assert!(!fields.contains_key("num"));
    }

    #[tokio::test]
    async fn scroll_all_with_point_ids_returns_point_id() {
        let store = seeded_store(vec![
            point(
                "pid-1",
                [1.0, 0.0, 0.0],
                &[
                    ("entity_id_str", serde_json::json!("42")),
                    ("name", serde_json::json!("Alpha")),
                    ("count", serde_json::json!(7)),
                ],
            ),
            point("pid-2", [0.0, 1.0, 0.0], &[("name", serde_json::json!("Beta"))]),
        ])
        .await;

        let result = store
            .scroll_all_with_point_ids("test", "entity_id_str")
            .await
            .unwrap();

        assert_eq!(
            result.len(),
            1,
            "only the point with key_field should appear"
        );
        let (point_id, fields) = &result[0];
        assert_eq!(point_id, "pid-1");
        assert_eq!(fields.get("entity_id_str").map(String::as_str), Some("42"));
        assert_eq!(fields.get("name").map(String::as_str), Some("Alpha"));
        assert!(!fields.contains_key("count"));
    }

    #[test]
    fn cosine_similarity_import_wired() {
        assert!(!cosine_similarity(&[1.0, 0.0, 0.0], &[0.0, 1.0, 0.0]).is_nan());
    }

    #[tokio::test]
    async fn default_impl() {
        let store = InMemoryVectorStore::default();
        assert!(!store.collection_exists("any").await.unwrap());
    }

    #[test]
    fn debug_format() {
        let store = InMemoryVectorStore::new();
        let dbg = format!("{store:?}");
        assert!(dbg.contains("InMemoryVectorStore"));
    }
}