1use std::collections::HashMap;
11
12use parking_lot::RwLock;
13
14use crate::vector_store::{
15 BoxFuture, FieldValue, ScoredVectorPoint, VectorFilter, VectorPoint, VectorStore,
16 VectorStoreError,
17};
18
/// One point held in a collection: its embedding and its JSON payload.
struct StoredPoint {
    // Embedding as given to `upsert`; its length is not checked against the collection.
    vector: Vec<f32>,
    // Payload fields; filters match against these and search hits return a clone of them.
    payload: HashMap<String, serde_json::Value>,
}
23
/// A named collection of points.
struct InMemoryCollection {
    // Points keyed by their id; upserting an existing id replaces the stored point.
    points: HashMap<String, StoredPoint>,
}
27
/// A [`VectorStore`] that keeps every collection in process memory.
///
/// Search is a linear scan that scores every (filter-matching) point with cosine similarity,
/// so it is intended for small data sets and tests rather than large corpora.
/// Nothing is persisted: dropping the store discards all collections.
pub struct InMemoryVectorStore {
    // Collection name -> collection. A `parking_lot::RwLock` guards it; every trait method
    // takes and releases the lock without awaiting in between.
    collections: RwLock<HashMap<String, InMemoryCollection>>,
}
43
44impl InMemoryVectorStore {
45 #[must_use]
46 pub fn new() -> Self {
47 Self {
48 collections: RwLock::new(HashMap::new()),
49 }
50 }
51}
52
53impl Default for InMemoryVectorStore {
54 fn default() -> Self {
55 Self::new()
56 }
57}
58
59impl std::fmt::Debug for InMemoryVectorStore {
60 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61 f.debug_struct("InMemoryVectorStore")
62 .finish_non_exhaustive()
63 }
64}
65
66use zeph_common::math::cosine_similarity;
67
68fn matches_filter(payload: &HashMap<String, serde_json::Value>, filter: &VectorFilter) -> bool {
69 for cond in &filter.must {
70 let Some(val) = payload.get(&cond.field) else {
71 return false;
72 };
73 if !field_matches(val, &cond.value) {
74 return false;
75 }
76 }
77 for cond in &filter.must_not {
78 if let Some(val) = payload.get(&cond.field)
79 && field_matches(val, &cond.value)
80 {
81 return false;
82 }
83 }
84 true
85}
86
87fn field_matches(val: &serde_json::Value, expected: &FieldValue) -> bool {
88 match expected {
89 FieldValue::Integer(i) => val.as_i64() == Some(*i),
90 FieldValue::Text(s) => val.as_str() == Some(s.as_str()),
91 }
92}
93
94impl VectorStore for InMemoryVectorStore {
95 fn ensure_collection(
96 &self,
97 collection: &str,
98 _vector_size: u64,
99 ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
100 let collection = collection.to_owned();
101 Box::pin(async move {
102 let mut cols = self.collections.write();
103 cols.entry(collection)
104 .or_insert_with(|| InMemoryCollection {
105 points: HashMap::new(),
106 });
107 Ok(())
108 })
109 }
110
111 fn collection_exists(&self, collection: &str) -> BoxFuture<'_, Result<bool, VectorStoreError>> {
112 let collection = collection.to_owned();
113 Box::pin(async move {
114 let cols = self.collections.read();
115 Ok(cols.contains_key(&collection))
116 })
117 }
118
119 fn delete_collection(&self, collection: &str) -> BoxFuture<'_, Result<(), VectorStoreError>> {
120 let collection = collection.to_owned();
121 Box::pin(async move {
122 let mut cols = self.collections.write();
123 cols.remove(&collection);
124 Ok(())
125 })
126 }
127
128 fn upsert(
129 &self,
130 collection: &str,
131 points: Vec<VectorPoint>,
132 ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
133 let collection = collection.to_owned();
134 Box::pin(async move {
135 let mut cols = self.collections.write();
136 let col = cols.get_mut(&collection).ok_or_else(|| {
137 VectorStoreError::Upsert(format!("collection {collection} not found"))
138 })?;
139 for p in points {
140 col.points.insert(
141 p.id,
142 StoredPoint {
143 vector: p.vector,
144 payload: p.payload,
145 },
146 );
147 }
148 Ok(())
149 })
150 }
151
152 fn search(
153 &self,
154 collection: &str,
155 vector: Vec<f32>,
156 limit: u64,
157 filter: Option<VectorFilter>,
158 ) -> BoxFuture<'_, Result<Vec<ScoredVectorPoint>, VectorStoreError>> {
159 let collection = collection.to_owned();
160 Box::pin(async move {
161 let cols = self.collections.read();
162 let col = cols.get(&collection).ok_or_else(|| {
163 VectorStoreError::Search(format!("collection {collection} not found"))
164 })?;
165
166 let empty_filter = VectorFilter::default();
167 let f = filter.as_ref().unwrap_or(&empty_filter);
168
169 let mut scored: Vec<ScoredVectorPoint> = col
170 .points
171 .iter()
172 .filter(|(_, sp)| matches_filter(&sp.payload, f))
173 .map(|(id, sp)| ScoredVectorPoint {
174 id: id.clone(),
175 score: cosine_similarity(&vector, &sp.vector),
176 payload: sp.payload.clone(),
177 })
178 .collect();
179
180 scored.sort_by(|a, b| {
181 b.score
182 .partial_cmp(&a.score)
183 .unwrap_or(std::cmp::Ordering::Equal)
184 });
185 #[expect(clippy::cast_possible_truncation)]
186 scored.truncate(limit as usize);
187 Ok(scored)
188 })
189 }
190
191 fn delete_by_ids(
192 &self,
193 collection: &str,
194 ids: Vec<String>,
195 ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
196 let collection = collection.to_owned();
197 Box::pin(async move {
198 if ids.is_empty() {
199 return Ok(());
200 }
201 let mut cols = self.collections.write();
202 let col = cols.get_mut(&collection).ok_or_else(|| {
203 VectorStoreError::Delete(format!("collection {collection} not found"))
204 })?;
205 for id in &ids {
206 col.points.remove(id);
207 }
208 Ok(())
209 })
210 }
211
212 fn scroll_all(
213 &self,
214 collection: &str,
215 key_field: &str,
216 ) -> BoxFuture<'_, Result<HashMap<String, HashMap<String, String>>, VectorStoreError>> {
217 let collection = collection.to_owned();
218 let key_field = key_field.to_owned();
219 Box::pin(async move {
220 let cols = self.collections.read();
221 let col = cols.get(&collection).ok_or_else(|| {
222 VectorStoreError::Scroll(format!("collection {collection} not found"))
223 })?;
224
225 let mut result = HashMap::new();
226 for sp in col.points.values() {
227 let Some(key_val) = sp.payload.get(&key_field).and_then(|v| v.as_str()) else {
228 continue;
229 };
230 let mut fields = HashMap::new();
231 for (k, v) in &sp.payload {
232 if let Some(s) = v.as_str() {
233 fields.insert(k.clone(), s.to_owned());
234 }
235 }
236 result.insert(key_val.to_owned(), fields);
237 }
238 Ok(result)
239 })
240 }
241
242 fn health_check(&self) -> BoxFuture<'_, Result<bool, VectorStoreError>> {
243 Box::pin(async { Ok(true) })
244 }
245}
246
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn ensure_collection_and_exists() {
        let store = InMemoryVectorStore::new();
        assert!(!store.collection_exists("test").await.unwrap());
        store.ensure_collection("test", 3).await.unwrap();
        assert!(store.collection_exists("test").await.unwrap());
    }

    #[tokio::test]
    async fn ensure_collection_idempotent() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();
        store.ensure_collection("test", 3).await.unwrap();
        assert!(store.collection_exists("test").await.unwrap());
    }

    #[tokio::test]
    async fn delete_collection_removes() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();
        store.delete_collection("test").await.unwrap();
        assert!(!store.collection_exists("test").await.unwrap());
    }

    #[tokio::test]
    async fn upsert_and_search() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();

        let points = vec![
            VectorPoint {
                id: "a".into(),
                vector: vec![1.0, 0.0, 0.0],
                payload: HashMap::from([("name".into(), serde_json::json!("alpha"))]),
            },
            VectorPoint {
                id: "b".into(),
                vector: vec![0.0, 1.0, 0.0],
                payload: HashMap::from([("name".into(), serde_json::json!("beta"))]),
            },
        ];
        store.upsert("test", points).await.unwrap();

        let results = store
            .search("test", vec![1.0, 0.0, 0.0], 2, None)
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "a");
        assert!((results[0].score - 1.0).abs() < f32::EPSILON);
    }

    #[tokio::test]
    async fn search_respects_limit_and_orders_by_score() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 2).await.unwrap();

        let points = vec![
            VectorPoint {
                id: "far".into(),
                vector: vec![0.0, 1.0],
                payload: HashMap::new(),
            },
            VectorPoint {
                id: "near".into(),
                vector: vec![1.0, 0.0],
                payload: HashMap::new(),
            },
            VectorPoint {
                id: "mid".into(),
                vector: vec![1.0, 1.0],
                payload: HashMap::new(),
            },
        ];
        store.upsert("test", points).await.unwrap();

        let results = store.search("test", vec![1.0, 0.0], 2, None).await.unwrap();
        let ids: Vec<&str> = results.iter().map(|r| r.id.as_str()).collect();
        assert_eq!(ids, ["near", "mid"]);

        let none = store.search("test", vec![1.0, 0.0], 0, None).await.unwrap();
        assert!(none.is_empty());
    }

    #[tokio::test]
    async fn search_with_filter() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();

        let points = vec![
            VectorPoint {
                id: "a".into(),
                vector: vec![1.0, 0.0, 0.0],
                payload: HashMap::from([("role".into(), serde_json::json!("user"))]),
            },
            VectorPoint {
                id: "b".into(),
                vector: vec![0.9, 0.1, 0.0],
                payload: HashMap::from([("role".into(), serde_json::json!("assistant"))]),
            },
        ];
        store.upsert("test", points).await.unwrap();

        let filter = VectorFilter {
            must: vec![crate::vector_store::FieldCondition {
                field: "role".into(),
                value: FieldValue::Text("user".into()),
            }],
            must_not: vec![],
        };
        let results = store
            .search("test", vec![1.0, 0.0, 0.0], 10, Some(filter))
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "a");
    }

    #[tokio::test]
    async fn search_with_must_not_filter() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();

        let points = vec![
            VectorPoint {
                id: "a".into(),
                vector: vec![1.0, 0.0, 0.0],
                payload: HashMap::from([("role".into(), serde_json::json!("user"))]),
            },
            VectorPoint {
                id: "b".into(),
                vector: vec![0.9, 0.1, 0.0],
                payload: HashMap::from([("role".into(), serde_json::json!("assistant"))]),
            },
            // No "role" field at all: a must_not condition on it must not exclude the point.
            VectorPoint {
                id: "c".into(),
                vector: vec![0.0, 0.0, 1.0],
                payload: HashMap::new(),
            },
        ];
        store.upsert("test", points).await.unwrap();

        let filter = VectorFilter {
            must: vec![],
            must_not: vec![crate::vector_store::FieldCondition {
                field: "role".into(),
                value: FieldValue::Text("assistant".into()),
            }],
        };
        let results = store
            .search("test", vec![1.0, 0.0, 0.0], 10, Some(filter))
            .await
            .unwrap();
        let mut ids: Vec<&str> = results.iter().map(|r| r.id.as_str()).collect();
        ids.sort_unstable();
        assert_eq!(ids, ["a", "c"]);
    }

    #[tokio::test]
    async fn operations_on_missing_collection_fail() {
        let store = InMemoryVectorStore::new();
        assert!(matches!(
            store.upsert("missing", vec![]).await,
            Err(VectorStoreError::Upsert(_))
        ));
        assert!(matches!(
            store.search("missing", vec![1.0], 1, None).await,
            Err(VectorStoreError::Search(_))
        ));
        assert!(matches!(
            store.delete_by_ids("missing", vec!["a".into()]).await,
            Err(VectorStoreError::Delete(_))
        ));
        assert!(matches!(
            store.scroll_all("missing", "name").await,
            Err(VectorStoreError::Scroll(_))
        ));
        // An empty id list is a no-op even when the collection does not exist.
        assert!(store.delete_by_ids("missing", vec![]).await.is_ok());
    }

    #[tokio::test]
    async fn upsert_overwrites_existing_id() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();

        for name in ["alpha", "omega"] {
            let points = vec![VectorPoint {
                id: "a".into(),
                vector: vec![1.0, 0.0, 0.0],
                payload: HashMap::from([("name".into(), serde_json::json!(name))]),
            }];
            store.upsert("test", points).await.unwrap();
        }

        let result = store.scroll_all("test", "name").await.unwrap();
        assert_eq!(result.len(), 1);
        assert!(result.contains_key("omega"));
        assert!(!result.contains_key("alpha"));
    }

    #[tokio::test]
    async fn delete_by_ids_removes_points() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();

        let points = vec![VectorPoint {
            id: "a".into(),
            vector: vec![1.0, 0.0, 0.0],
            payload: HashMap::new(),
        }];
        store.upsert("test", points).await.unwrap();
        store.delete_by_ids("test", vec!["a".into()]).await.unwrap();

        let results = store
            .search("test", vec![1.0, 0.0, 0.0], 10, None)
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn scroll_all_extracts_strings() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();

        let points = vec![VectorPoint {
            id: "a".into(),
            vector: vec![1.0, 0.0, 0.0],
            payload: HashMap::from([
                ("name".into(), serde_json::json!("alpha")),
                ("desc".into(), serde_json::json!("first")),
                ("num".into(), serde_json::json!(42)),
            ]),
        }];
        store.upsert("test", points).await.unwrap();

        let result = store.scroll_all("test", "name").await.unwrap();
        assert_eq!(result.len(), 1);
        let fields = result.get("alpha").unwrap();
        assert_eq!(fields.get("desc").unwrap(), "first");
        assert!(!fields.contains_key("num"));
    }

    #[test]
    fn cosine_similarity_import_wired() {
        assert!(!cosine_similarity(&[1.0, 0.0, 0.0], &[0.0, 1.0, 0.0]).is_nan());
    }

    #[tokio::test]
    async fn default_impl() {
        let store = InMemoryVectorStore::default();
        assert!(!store.collection_exists("any").await.unwrap());
    }

    #[test]
    fn debug_format() {
        let store = InMemoryVectorStore::new();
        let dbg = format!("{store:?}");
        assert!(dbg.contains("InMemoryVectorStore"));
    }
}