1use std::collections::HashMap;
5use std::sync::RwLock;
6
7use crate::vector_store::{
8 BoxFuture, FieldValue, ScoredVectorPoint, VectorFilter, VectorPoint, VectorStore,
9 VectorStoreError,
10};
11
/// One point held in an in-memory collection: the vector used for similarity
/// scoring plus the payload used for filtering and returned to callers.
struct StoredPoint {
    /// Vector exactly as given to `upsert`; its length is not checked against
    /// any declared collection size.
    vector: Vec<f32>,
    /// JSON payload fields, keyed by field name.
    payload: HashMap<String, serde_json::Value>,
}
16
/// Contents of one named collection.
struct InMemoryCollection {
    /// Stored points keyed by point id; upserting an existing id replaces the
    /// previous point.
    points: HashMap<String, StoredPoint>,
}
20
/// A [`VectorStore`] that keeps every collection in process memory.
///
/// Nothing is persisted. Searches scan every point of the target collection
/// and score it with cosine similarity.
pub struct InMemoryVectorStore {
    /// Collections keyed by name, guarded by one lock shared by all operations.
    collections: RwLock<HashMap<String, InMemoryCollection>>,
}
24
25impl InMemoryVectorStore {
26 #[must_use]
27 pub fn new() -> Self {
28 Self {
29 collections: RwLock::new(HashMap::new()),
30 }
31 }
32}
33
34impl Default for InMemoryVectorStore {
35 fn default() -> Self {
36 Self::new()
37 }
38}
39
40impl std::fmt::Debug for InMemoryVectorStore {
41 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42 f.debug_struct("InMemoryVectorStore")
43 .finish_non_exhaustive()
44 }
45}
46
47use crate::math::cosine_similarity;
48
49fn matches_filter(payload: &HashMap<String, serde_json::Value>, filter: &VectorFilter) -> bool {
50 for cond in &filter.must {
51 let Some(val) = payload.get(&cond.field) else {
52 return false;
53 };
54 if !field_matches(val, &cond.value) {
55 return false;
56 }
57 }
58 for cond in &filter.must_not {
59 if let Some(val) = payload.get(&cond.field)
60 && field_matches(val, &cond.value)
61 {
62 return false;
63 }
64 }
65 true
66}
67
68fn field_matches(val: &serde_json::Value, expected: &FieldValue) -> bool {
69 match expected {
70 FieldValue::Integer(i) => val.as_i64() == Some(*i),
71 FieldValue::Text(s) => val.as_str() == Some(s.as_str()),
72 }
73}
74
75impl VectorStore for InMemoryVectorStore {
76 fn ensure_collection(
77 &self,
78 collection: &str,
79 _vector_size: u64,
80 ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
81 let collection = collection.to_owned();
82 Box::pin(async move {
83 let mut cols = self
84 .collections
85 .write()
86 .map_err(|e| VectorStoreError::Collection(e.to_string()))?;
87 cols.entry(collection)
88 .or_insert_with(|| InMemoryCollection {
89 points: HashMap::new(),
90 });
91 Ok(())
92 })
93 }
94
95 fn collection_exists(&self, collection: &str) -> BoxFuture<'_, Result<bool, VectorStoreError>> {
96 let collection = collection.to_owned();
97 Box::pin(async move {
98 let cols = self
99 .collections
100 .read()
101 .map_err(|e| VectorStoreError::Collection(e.to_string()))?;
102 Ok(cols.contains_key(&collection))
103 })
104 }
105
106 fn delete_collection(&self, collection: &str) -> BoxFuture<'_, Result<(), VectorStoreError>> {
107 let collection = collection.to_owned();
108 Box::pin(async move {
109 let mut cols = self
110 .collections
111 .write()
112 .map_err(|e| VectorStoreError::Collection(e.to_string()))?;
113 cols.remove(&collection);
114 Ok(())
115 })
116 }
117
118 fn upsert(
119 &self,
120 collection: &str,
121 points: Vec<VectorPoint>,
122 ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
123 let collection = collection.to_owned();
124 Box::pin(async move {
125 let mut cols = self
126 .collections
127 .write()
128 .map_err(|e| VectorStoreError::Upsert(e.to_string()))?;
129 let col = cols.get_mut(&collection).ok_or_else(|| {
130 VectorStoreError::Upsert(format!("collection {collection} not found"))
131 })?;
132 for p in points {
133 col.points.insert(
134 p.id,
135 StoredPoint {
136 vector: p.vector,
137 payload: p.payload,
138 },
139 );
140 }
141 Ok(())
142 })
143 }
144
145 fn search(
146 &self,
147 collection: &str,
148 vector: Vec<f32>,
149 limit: u64,
150 filter: Option<VectorFilter>,
151 ) -> BoxFuture<'_, Result<Vec<ScoredVectorPoint>, VectorStoreError>> {
152 let collection = collection.to_owned();
153 Box::pin(async move {
154 let cols = self
155 .collections
156 .read()
157 .map_err(|e| VectorStoreError::Search(e.to_string()))?;
158 let col = cols.get(&collection).ok_or_else(|| {
159 VectorStoreError::Search(format!("collection {collection} not found"))
160 })?;
161
162 let empty_filter = VectorFilter::default();
163 let f = filter.as_ref().unwrap_or(&empty_filter);
164
165 let mut scored: Vec<ScoredVectorPoint> = col
166 .points
167 .iter()
168 .filter(|(_, sp)| matches_filter(&sp.payload, f))
169 .map(|(id, sp)| ScoredVectorPoint {
170 id: id.clone(),
171 score: cosine_similarity(&vector, &sp.vector),
172 payload: sp.payload.clone(),
173 })
174 .collect();
175
176 scored.sort_by(|a, b| {
177 b.score
178 .partial_cmp(&a.score)
179 .unwrap_or(std::cmp::Ordering::Equal)
180 });
181 #[expect(clippy::cast_possible_truncation)]
182 scored.truncate(limit as usize);
183 Ok(scored)
184 })
185 }
186
187 fn delete_by_ids(
188 &self,
189 collection: &str,
190 ids: Vec<String>,
191 ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
192 let collection = collection.to_owned();
193 Box::pin(async move {
194 if ids.is_empty() {
195 return Ok(());
196 }
197 let mut cols = self
198 .collections
199 .write()
200 .map_err(|e| VectorStoreError::Delete(e.to_string()))?;
201 let col = cols.get_mut(&collection).ok_or_else(|| {
202 VectorStoreError::Delete(format!("collection {collection} not found"))
203 })?;
204 for id in &ids {
205 col.points.remove(id);
206 }
207 Ok(())
208 })
209 }
210
211 fn scroll_all(
212 &self,
213 collection: &str,
214 key_field: &str,
215 ) -> BoxFuture<'_, Result<HashMap<String, HashMap<String, String>>, VectorStoreError>> {
216 let collection = collection.to_owned();
217 let key_field = key_field.to_owned();
218 Box::pin(async move {
219 let cols = self
220 .collections
221 .read()
222 .map_err(|e| VectorStoreError::Scroll(e.to_string()))?;
223 let col = cols.get(&collection).ok_or_else(|| {
224 VectorStoreError::Scroll(format!("collection {collection} not found"))
225 })?;
226
227 let mut result = HashMap::new();
228 for sp in col.points.values() {
229 let Some(key_val) = sp.payload.get(&key_field).and_then(|v| v.as_str()) else {
230 continue;
231 };
232 let mut fields = HashMap::new();
233 for (k, v) in &sp.payload {
234 if let Some(s) = v.as_str() {
235 fields.insert(k.clone(), s.to_owned());
236 }
237 }
238 result.insert(key_val.to_owned(), fields);
239 }
240 Ok(result)
241 })
242 }
243
244 fn health_check(&self) -> BoxFuture<'_, Result<bool, VectorStoreError>> {
245 Box::pin(async { Ok(true) })
246 }
247}
248
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `VectorPoint` with a small payload given as key/value pairs.
    fn point(id: &str, vector: Vec<f32>, payload: &[(&str, serde_json::Value)]) -> VectorPoint {
        VectorPoint {
            id: id.into(),
            vector,
            payload: payload
                .iter()
                .map(|(key, value)| ((*key).to_owned(), value.clone()))
                .collect(),
        }
    }

    /// Builds a filter condition on `field`.
    fn condition(field: &str, value: FieldValue) -> crate::vector_store::FieldCondition {
        crate::vector_store::FieldCondition {
            field: field.into(),
            value,
        }
    }

    #[tokio::test]
    async fn ensure_collection_and_exists() {
        let store = InMemoryVectorStore::new();
        assert!(!store.collection_exists("test").await.unwrap());
        store.ensure_collection("test", 3).await.unwrap();
        assert!(store.collection_exists("test").await.unwrap());
    }

    #[tokio::test]
    async fn ensure_collection_idempotent() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();
        store.ensure_collection("test", 3).await.unwrap();
        assert!(store.collection_exists("test").await.unwrap());
    }

    #[tokio::test]
    async fn delete_collection_removes() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();
        store.delete_collection("test").await.unwrap();
        assert!(!store.collection_exists("test").await.unwrap());
    }

    #[tokio::test]
    async fn upsert_and_search() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();

        let points = vec![
            VectorPoint {
                id: "a".into(),
                vector: vec![1.0, 0.0, 0.0],
                payload: HashMap::from([("name".into(), serde_json::json!("alpha"))]),
            },
            VectorPoint {
                id: "b".into(),
                vector: vec![0.0, 1.0, 0.0],
                payload: HashMap::from([("name".into(), serde_json::json!("beta"))]),
            },
        ];
        store.upsert("test", points).await.unwrap();

        let results = store
            .search("test", vec![1.0, 0.0, 0.0], 2, None)
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "a");
        assert!((results[0].score - 1.0).abs() < f32::EPSILON);
    }

    #[tokio::test]
    async fn search_with_filter() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();

        let points = vec![
            VectorPoint {
                id: "a".into(),
                vector: vec![1.0, 0.0, 0.0],
                payload: HashMap::from([("role".into(), serde_json::json!("user"))]),
            },
            VectorPoint {
                id: "b".into(),
                vector: vec![0.9, 0.1, 0.0],
                payload: HashMap::from([("role".into(), serde_json::json!("assistant"))]),
            },
        ];
        store.upsert("test", points).await.unwrap();

        let filter = VectorFilter {
            must: vec![crate::vector_store::FieldCondition {
                field: "role".into(),
                value: FieldValue::Text("user".into()),
            }],
            must_not: vec![],
        };
        let results = store
            .search("test", vec![1.0, 0.0, 0.0], 10, Some(filter))
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "a");
    }

    #[tokio::test]
    async fn search_with_must_not_filter() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();

        let points = vec![
            point("a", vec![1.0, 0.0, 0.0], &[("role", serde_json::json!("user"))]),
            point("b", vec![0.9, 0.1, 0.0], &[("role", serde_json::json!("assistant"))]),
            // No "role" field: a `must_not` condition on it cannot exclude this point.
            point("c", vec![0.0, 0.0, 1.0], &[]),
        ];
        store.upsert("test", points).await.unwrap();

        let filter = VectorFilter {
            must: vec![],
            must_not: vec![condition("role", FieldValue::Text("user".into()))],
        };
        let results = store
            .search("test", vec![1.0, 0.0, 0.0], 10, Some(filter))
            .await
            .unwrap();
        let mut ids: Vec<&str> = results.iter().map(|r| r.id.as_str()).collect();
        ids.sort_unstable();
        assert_eq!(ids, ["b", "c"]);
    }

    #[tokio::test]
    async fn search_with_integer_filter_skips_missing_field() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();

        let points = vec![
            point("a", vec![1.0, 0.0, 0.0], &[("n", serde_json::json!(42))]),
            point("b", vec![1.0, 0.0, 0.0], &[("n", serde_json::json!(7))]),
            // No "n" field: a `must` condition on it always rejects this point.
            point("c", vec![1.0, 0.0, 0.0], &[]),
        ];
        store.upsert("test", points).await.unwrap();

        let filter = VectorFilter {
            must: vec![condition("n", FieldValue::Integer(42))],
            must_not: vec![],
        };
        let results = store
            .search("test", vec![1.0, 0.0, 0.0], 10, Some(filter))
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "a");
    }

    #[tokio::test]
    async fn search_respects_limit() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();

        let points = vec![
            point("a", vec![1.0, 0.0, 0.0], &[]),
            point("b", vec![0.0, 1.0, 0.0], &[]),
            point("c", vec![0.9, 0.1, 0.0], &[]),
        ];
        store.upsert("test", points).await.unwrap();

        let top = store
            .search("test", vec![1.0, 0.0, 0.0], 1, None)
            .await
            .unwrap();
        assert_eq!(top.len(), 1);
        assert_eq!(top[0].id, "a");

        let none = store
            .search("test", vec![1.0, 0.0, 0.0], 0, None)
            .await
            .unwrap();
        assert!(none.is_empty());
    }

    #[tokio::test]
    async fn upsert_replaces_point_with_same_id() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();

        store
            .upsert(
                "test",
                vec![point("a", vec![1.0, 0.0, 0.0], &[("name", serde_json::json!("alpha"))])],
            )
            .await
            .unwrap();
        store
            .upsert(
                "test",
                vec![point("a", vec![0.0, 1.0, 0.0], &[("name", serde_json::json!("beta"))])],
            )
            .await
            .unwrap();

        let result = store.scroll_all("test", "name").await.unwrap();
        assert_eq!(result.len(), 1);
        assert!(result.contains_key("beta"));
        assert!(!result.contains_key("alpha"));
    }

    #[tokio::test]
    async fn operations_on_missing_collection_fail() {
        let store = InMemoryVectorStore::new();
        assert!(matches!(
            store.upsert("missing", vec![]).await,
            Err(VectorStoreError::Upsert(_))
        ));
        assert!(matches!(
            store.search("missing", vec![1.0], 1, None).await,
            Err(VectorStoreError::Search(_))
        ));
        assert!(matches!(
            store.delete_by_ids("missing", vec!["a".into()]).await,
            Err(VectorStoreError::Delete(_))
        ));
        assert!(matches!(
            store.scroll_all("missing", "name").await,
            Err(VectorStoreError::Scroll(_))
        ));
        // Deleting nothing short-circuits before the collection lookup.
        assert!(store.delete_by_ids("missing", vec![]).await.is_ok());
    }

    #[tokio::test]
    async fn delete_by_ids_removes_points() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();

        let points = vec![VectorPoint {
            id: "a".into(),
            vector: vec![1.0, 0.0, 0.0],
            payload: HashMap::new(),
        }];
        store.upsert("test", points).await.unwrap();
        store.delete_by_ids("test", vec!["a".into()]).await.unwrap();

        let results = store
            .search("test", vec![1.0, 0.0, 0.0], 10, None)
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn scroll_all_extracts_strings() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();

        let points = vec![VectorPoint {
            id: "a".into(),
            vector: vec![1.0, 0.0, 0.0],
            payload: HashMap::from([
                ("name".into(), serde_json::json!("alpha")),
                ("desc".into(), serde_json::json!("first")),
                ("num".into(), serde_json::json!(42)),
            ]),
        }];
        store.upsert("test", points).await.unwrap();

        let result = store.scroll_all("test", "name").await.unwrap();
        assert_eq!(result.len(), 1);
        let fields = result.get("alpha").unwrap();
        assert_eq!(fields.get("desc").unwrap(), "first");
        assert!(!fields.contains_key("num"));
    }

    #[test]
    fn cosine_similarity_import_wired() {
        assert!(!cosine_similarity(&[1.0, 0.0, 0.0], &[0.0, 1.0, 0.0]).is_nan());
    }

    #[tokio::test]
    async fn default_impl() {
        let store = InMemoryVectorStore::default();
        assert!(!store.collection_exists("any").await.unwrap());
    }

    #[test]
    fn debug_format() {
        let store = InMemoryVectorStore::new();
        let dbg = format!("{store:?}");
        assert!(dbg.contains("InMemoryVectorStore"));
    }
}