1use std::collections::HashMap;
5use std::future::Future;
6use std::pin::Pin;
7use std::sync::RwLock;
8
9use crate::vector_store::{
10 FieldValue, ScoredVectorPoint, VectorFilter, VectorPoint, VectorStore, VectorStoreError,
11};
12
/// Heap-allocated, type-erased future used as the return type of every
/// `VectorStore` method in this file.
///
/// The future is `Send`, so it may be moved between threads, and it may borrow
/// from its receiver for the lifetime `'fut`.
type BoxFuture<'fut, Out> = Pin<Box<dyn Future<Output = Out> + Send + 'fut>>;
14
15struct StoredPoint {
16 vector: Vec<f32>,
17 payload: HashMap<String, serde_json::Value>,
18}
19
20struct InMemoryCollection {
21 points: HashMap<String, StoredPoint>,
22}
23
24pub struct InMemoryVectorStore {
25 collections: RwLock<HashMap<String, InMemoryCollection>>,
26}
27
28impl InMemoryVectorStore {
29 #[must_use]
30 pub fn new() -> Self {
31 Self {
32 collections: RwLock::new(HashMap::new()),
33 }
34 }
35}
36
37impl Default for InMemoryVectorStore {
38 fn default() -> Self {
39 Self::new()
40 }
41}
42
43impl std::fmt::Debug for InMemoryVectorStore {
44 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45 f.debug_struct("InMemoryVectorStore")
46 .finish_non_exhaustive()
47 }
48}
49
50use crate::math::cosine_similarity;
51
52fn matches_filter(payload: &HashMap<String, serde_json::Value>, filter: &VectorFilter) -> bool {
53 for cond in &filter.must {
54 let Some(val) = payload.get(&cond.field) else {
55 return false;
56 };
57 if !field_matches(val, &cond.value) {
58 return false;
59 }
60 }
61 for cond in &filter.must_not {
62 if let Some(val) = payload.get(&cond.field)
63 && field_matches(val, &cond.value)
64 {
65 return false;
66 }
67 }
68 true
69}
70
71fn field_matches(val: &serde_json::Value, expected: &FieldValue) -> bool {
72 match expected {
73 FieldValue::Integer(i) => val.as_i64() == Some(*i),
74 FieldValue::Text(s) => val.as_str() == Some(s.as_str()),
75 }
76}
77
78impl VectorStore for InMemoryVectorStore {
79 fn ensure_collection(
80 &self,
81 collection: &str,
82 _vector_size: u64,
83 ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
84 let collection = collection.to_owned();
85 Box::pin(async move {
86 let mut cols = self
87 .collections
88 .write()
89 .map_err(|e| VectorStoreError::Collection(e.to_string()))?;
90 cols.entry(collection)
91 .or_insert_with(|| InMemoryCollection {
92 points: HashMap::new(),
93 });
94 Ok(())
95 })
96 }
97
98 fn collection_exists(&self, collection: &str) -> BoxFuture<'_, Result<bool, VectorStoreError>> {
99 let collection = collection.to_owned();
100 Box::pin(async move {
101 let cols = self
102 .collections
103 .read()
104 .map_err(|e| VectorStoreError::Collection(e.to_string()))?;
105 Ok(cols.contains_key(&collection))
106 })
107 }
108
109 fn delete_collection(&self, collection: &str) -> BoxFuture<'_, Result<(), VectorStoreError>> {
110 let collection = collection.to_owned();
111 Box::pin(async move {
112 let mut cols = self
113 .collections
114 .write()
115 .map_err(|e| VectorStoreError::Collection(e.to_string()))?;
116 cols.remove(&collection);
117 Ok(())
118 })
119 }
120
121 fn upsert(
122 &self,
123 collection: &str,
124 points: Vec<VectorPoint>,
125 ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
126 let collection = collection.to_owned();
127 Box::pin(async move {
128 let mut cols = self
129 .collections
130 .write()
131 .map_err(|e| VectorStoreError::Upsert(e.to_string()))?;
132 let col = cols.get_mut(&collection).ok_or_else(|| {
133 VectorStoreError::Upsert(format!("collection {collection} not found"))
134 })?;
135 for p in points {
136 col.points.insert(
137 p.id,
138 StoredPoint {
139 vector: p.vector,
140 payload: p.payload,
141 },
142 );
143 }
144 Ok(())
145 })
146 }
147
148 fn search(
149 &self,
150 collection: &str,
151 vector: Vec<f32>,
152 limit: u64,
153 filter: Option<VectorFilter>,
154 ) -> BoxFuture<'_, Result<Vec<ScoredVectorPoint>, VectorStoreError>> {
155 let collection = collection.to_owned();
156 Box::pin(async move {
157 let cols = self
158 .collections
159 .read()
160 .map_err(|e| VectorStoreError::Search(e.to_string()))?;
161 let col = cols.get(&collection).ok_or_else(|| {
162 VectorStoreError::Search(format!("collection {collection} not found"))
163 })?;
164
165 let empty_filter = VectorFilter::default();
166 let f = filter.as_ref().unwrap_or(&empty_filter);
167
168 let mut scored: Vec<ScoredVectorPoint> = col
169 .points
170 .iter()
171 .filter(|(_, sp)| matches_filter(&sp.payload, f))
172 .map(|(id, sp)| ScoredVectorPoint {
173 id: id.clone(),
174 score: cosine_similarity(&vector, &sp.vector),
175 payload: sp.payload.clone(),
176 })
177 .collect();
178
179 scored.sort_by(|a, b| {
180 b.score
181 .partial_cmp(&a.score)
182 .unwrap_or(std::cmp::Ordering::Equal)
183 });
184 #[expect(clippy::cast_possible_truncation)]
185 scored.truncate(limit as usize);
186 Ok(scored)
187 })
188 }
189
190 fn delete_by_ids(
191 &self,
192 collection: &str,
193 ids: Vec<String>,
194 ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
195 let collection = collection.to_owned();
196 Box::pin(async move {
197 if ids.is_empty() {
198 return Ok(());
199 }
200 let mut cols = self
201 .collections
202 .write()
203 .map_err(|e| VectorStoreError::Delete(e.to_string()))?;
204 let col = cols.get_mut(&collection).ok_or_else(|| {
205 VectorStoreError::Delete(format!("collection {collection} not found"))
206 })?;
207 for id in &ids {
208 col.points.remove(id);
209 }
210 Ok(())
211 })
212 }
213
214 fn scroll_all(
215 &self,
216 collection: &str,
217 key_field: &str,
218 ) -> BoxFuture<'_, Result<HashMap<String, HashMap<String, String>>, VectorStoreError>> {
219 let collection = collection.to_owned();
220 let key_field = key_field.to_owned();
221 Box::pin(async move {
222 let cols = self
223 .collections
224 .read()
225 .map_err(|e| VectorStoreError::Scroll(e.to_string()))?;
226 let col = cols.get(&collection).ok_or_else(|| {
227 VectorStoreError::Scroll(format!("collection {collection} not found"))
228 })?;
229
230 let mut result = HashMap::new();
231 for sp in col.points.values() {
232 let Some(key_val) = sp.payload.get(&key_field).and_then(|v| v.as_str()) else {
233 continue;
234 };
235 let mut fields = HashMap::new();
236 for (k, v) in &sp.payload {
237 if let Some(s) = v.as_str() {
238 fields.insert(k.clone(), s.to_owned());
239 }
240 }
241 result.insert(key_val.to_owned(), fields);
242 }
243 Ok(result)
244 })
245 }
246
247 fn health_check(&self) -> BoxFuture<'_, Result<bool, VectorStoreError>> {
248 Box::pin(async { Ok(true) })
249 }
250}
251
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `VectorPoint` from compact literals (used by the newer tests).
    fn point(id: &str, vector: [f32; 3], payload: &[(&str, serde_json::Value)]) -> VectorPoint {
        VectorPoint {
            id: id.into(),
            vector: vector.to_vec(),
            payload: payload
                .iter()
                .map(|(k, v)| ((*k).to_owned(), v.clone()))
                .collect(),
        }
    }

    #[tokio::test]
    async fn ensure_collection_and_exists() {
        let store = InMemoryVectorStore::new();
        assert!(!store.collection_exists("test").await.unwrap());
        store.ensure_collection("test", 3).await.unwrap();
        assert!(store.collection_exists("test").await.unwrap());
    }

    #[tokio::test]
    async fn ensure_collection_idempotent() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();
        store.ensure_collection("test", 3).await.unwrap();
        assert!(store.collection_exists("test").await.unwrap());
    }

    #[tokio::test]
    async fn delete_collection_removes() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();
        store.delete_collection("test").await.unwrap();
        assert!(!store.collection_exists("test").await.unwrap());
    }

    #[tokio::test]
    async fn upsert_and_search() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();

        let points = vec![
            VectorPoint {
                id: "a".into(),
                vector: vec![1.0, 0.0, 0.0],
                payload: HashMap::from([("name".into(), serde_json::json!("alpha"))]),
            },
            VectorPoint {
                id: "b".into(),
                vector: vec![0.0, 1.0, 0.0],
                payload: HashMap::from([("name".into(), serde_json::json!("beta"))]),
            },
        ];
        store.upsert("test", points).await.unwrap();

        let results = store
            .search("test", vec![1.0, 0.0, 0.0], 2, None)
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "a");
        assert!((results[0].score - 1.0).abs() < f32::EPSILON);
    }

    #[tokio::test]
    async fn search_with_filter() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();

        let points = vec![
            VectorPoint {
                id: "a".into(),
                vector: vec![1.0, 0.0, 0.0],
                payload: HashMap::from([("role".into(), serde_json::json!("user"))]),
            },
            VectorPoint {
                id: "b".into(),
                vector: vec![0.9, 0.1, 0.0],
                payload: HashMap::from([("role".into(), serde_json::json!("assistant"))]),
            },
        ];
        store.upsert("test", points).await.unwrap();

        let filter = VectorFilter {
            must: vec![crate::vector_store::FieldCondition {
                field: "role".into(),
                value: FieldValue::Text("user".into()),
            }],
            must_not: vec![],
        };
        let results = store
            .search("test", vec![1.0, 0.0, 0.0], 10, Some(filter))
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "a");
    }

    #[tokio::test]
    async fn search_with_must_not_filter() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();
        store
            .upsert(
                "test",
                vec![
                    point("a", [1.0, 0.0, 0.0], &[("role", serde_json::json!("user"))]),
                    point("b", [0.9, 0.1, 0.0], &[("role", serde_json::json!("assistant"))]),
                    point("c", [0.0, 0.0, 1.0], &[]),
                ],
            )
            .await
            .unwrap();

        let filter = VectorFilter {
            must: vec![],
            must_not: vec![crate::vector_store::FieldCondition {
                field: "role".into(),
                value: FieldValue::Text("user".into()),
            }],
        };
        let results = store
            .search("test", vec![1.0, 0.0, 0.0], 10, Some(filter))
            .await
            .unwrap();
        // "c" has no "role" field, so the must_not condition cannot exclude it.
        let ids: Vec<&str> = results.iter().map(|r| r.id.as_str()).collect();
        assert_eq!(ids, ["b", "c"]);
    }

    #[tokio::test]
    async fn search_with_integer_filter() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();
        store
            .upsert(
                "test",
                vec![
                    point("one", [1.0, 0.0, 0.0], &[("n", serde_json::json!(1))]),
                    point("two", [1.0, 0.0, 0.0], &[("n", serde_json::json!(2))]),
                    // A numeric-looking string must not match an integer condition.
                    point("text", [1.0, 0.0, 0.0], &[("n", serde_json::json!("2"))]),
                ],
            )
            .await
            .unwrap();

        let filter = VectorFilter {
            must: vec![crate::vector_store::FieldCondition {
                field: "n".into(),
                value: FieldValue::Integer(2),
            }],
            must_not: vec![],
        };
        let results = store
            .search("test", vec![1.0, 0.0, 0.0], 10, Some(filter))
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "two");
    }

    #[tokio::test]
    async fn search_respects_limit() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();
        store
            .upsert(
                "test",
                vec![
                    point("a", [1.0, 0.0, 0.0], &[]),
                    point("b", [0.9, 0.1, 0.0], &[]),
                    point("c", [0.0, 1.0, 0.0], &[]),
                ],
            )
            .await
            .unwrap();

        let top = store
            .search("test", vec![1.0, 0.0, 0.0], 1, None)
            .await
            .unwrap();
        assert_eq!(top.len(), 1);
        assert_eq!(top[0].id, "a");

        let none = store
            .search("test", vec![1.0, 0.0, 0.0], 0, None)
            .await
            .unwrap();
        assert!(none.is_empty());
    }

    #[tokio::test]
    async fn operations_on_missing_collection() {
        let store = InMemoryVectorStore::new();
        let missing_point = point("a", [1.0, 0.0, 0.0], &[]);

        assert!(store.upsert("missing", vec![missing_point]).await.is_err());
        assert!(
            store
                .search("missing", vec![1.0, 0.0, 0.0], 10, None)
                .await
                .is_err()
        );
        assert!(store.scroll_all("missing", "name").await.is_err());
        assert!(
            store
                .delete_by_ids("missing", vec!["a".into()])
                .await
                .is_err()
        );
        // An empty id list short-circuits before the collection lookup.
        assert!(store.delete_by_ids("missing", vec![]).await.is_ok());
        // Deleting an absent collection is a no-op.
        assert!(store.delete_collection("missing").await.is_ok());
    }

    #[tokio::test]
    async fn upsert_replaces_existing_id() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();
        store
            .upsert(
                "test",
                vec![point("a", [1.0, 0.0, 0.0], &[("name", serde_json::json!("alpha"))])],
            )
            .await
            .unwrap();
        store
            .upsert(
                "test",
                vec![point("a", [0.0, 1.0, 0.0], &[("name", serde_json::json!("gamma"))])],
            )
            .await
            .unwrap();

        let result = store.scroll_all("test", "name").await.unwrap();
        assert_eq!(result.len(), 1);
        assert!(result.contains_key("gamma"));
    }

    #[tokio::test]
    async fn delete_by_ids_removes_points() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();

        let points = vec![VectorPoint {
            id: "a".into(),
            vector: vec![1.0, 0.0, 0.0],
            payload: HashMap::new(),
        }];
        store.upsert("test", points).await.unwrap();
        store.delete_by_ids("test", vec!["a".into()]).await.unwrap();

        let results = store
            .search("test", vec![1.0, 0.0, 0.0], 10, None)
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn scroll_all_extracts_strings() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();

        let points = vec![VectorPoint {
            id: "a".into(),
            vector: vec![1.0, 0.0, 0.0],
            payload: HashMap::from([
                ("name".into(), serde_json::json!("alpha")),
                ("desc".into(), serde_json::json!("first")),
                ("num".into(), serde_json::json!(42)),
            ]),
        }];
        store.upsert("test", points).await.unwrap();

        let result = store.scroll_all("test", "name").await.unwrap();
        assert_eq!(result.len(), 1);
        let fields = result.get("alpha").unwrap();
        assert_eq!(fields.get("desc").unwrap(), "first");
        assert!(!fields.contains_key("num"));
    }

    #[tokio::test]
    async fn scroll_all_skips_points_without_string_key() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();
        store
            .upsert(
                "test",
                vec![
                    point("a", [1.0, 0.0, 0.0], &[("name", serde_json::json!("alpha"))]),
                    point("b", [0.0, 1.0, 0.0], &[("name", serde_json::json!(7))]),
                    point("c", [0.0, 0.0, 1.0], &[]),
                ],
            )
            .await
            .unwrap();

        let result = store.scroll_all("test", "name").await.unwrap();
        assert_eq!(result.len(), 1);
        assert!(result.contains_key("alpha"));
    }

    #[tokio::test]
    async fn health_check_reports_healthy() {
        assert!(InMemoryVectorStore::new().health_check().await.unwrap());
    }

    #[test]
    fn cosine_similarity_import_wired() {
        assert!(!cosine_similarity(&[1.0, 0.0, 0.0], &[0.0, 1.0, 0.0]).is_nan());
    }

    #[tokio::test]
    async fn default_impl() {
        let store = InMemoryVectorStore::default();
        assert!(!store.collection_exists("any").await.unwrap());
    }

    #[test]
    fn debug_format() {
        let store = InMemoryVectorStore::new();
        let dbg = format!("{store:?}");
        assert!(dbg.contains("InMemoryVectorStore"));
    }
}