1use std::collections::HashMap;
5use std::future::Future;
6use std::pin::Pin;
7use std::sync::RwLock;
8
9use crate::vector_store::{
10 FieldValue, ScoredVectorPoint, VectorFilter, VectorPoint, VectorStore, VectorStoreError,
11};
12
/// Heap-allocated, pinned, `Send` future borrowing for `'a` and resolving to `T`;
/// the return type used by every `VectorStore` method in this file.
type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a + Send>>;
14
15struct StoredPoint {
16 vector: Vec<f32>,
17 payload: HashMap<String, serde_json::Value>,
18}
19
20struct InMemoryCollection {
21 points: HashMap<String, StoredPoint>,
22}
23
24pub struct InMemoryVectorStore {
25 collections: RwLock<HashMap<String, InMemoryCollection>>,
26}
27
28impl InMemoryVectorStore {
29 #[must_use]
30 pub fn new() -> Self {
31 Self {
32 collections: RwLock::new(HashMap::new()),
33 }
34 }
35}
36
37impl Default for InMemoryVectorStore {
38 fn default() -> Self {
39 Self::new()
40 }
41}
42
43impl std::fmt::Debug for InMemoryVectorStore {
44 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45 f.debug_struct("InMemoryVectorStore")
46 .finish_non_exhaustive()
47 }
48}
49
/// Cosine similarity of `a` and `b`: `dot(a, b) / (|a| * |b|)`, in `[-1, 1]`
/// up to rounding.
///
/// Returns `0.0` when the similarity is not meaningful:
/// - either vector has zero magnitude (including empty vectors);
/// - the vectors have different lengths. A pairwise dot product over the
///   shorter prefix divided by norms over the full vectors has no geometric
///   meaning, so mismatched dimensions are treated as "no similarity" rather
///   than producing an arbitrary score.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }
    // Single pass: accumulate the dot product and both squared norms together.
    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b).fold(
        (0.0_f32, 0.0_f32, 0.0_f32),
        |(dot, na, nb), (&x, &y)| (dot + x * y, na + x * x, nb + y * y),
    );
    if norm_a_sq == 0.0 || norm_b_sq == 0.0 {
        return 0.0;
    }
    dot / (norm_a_sq.sqrt() * norm_b_sq.sqrt())
}
59
60fn matches_filter(payload: &HashMap<String, serde_json::Value>, filter: &VectorFilter) -> bool {
61 for cond in &filter.must {
62 let Some(val) = payload.get(&cond.field) else {
63 return false;
64 };
65 if !field_matches(val, &cond.value) {
66 return false;
67 }
68 }
69 for cond in &filter.must_not {
70 if let Some(val) = payload.get(&cond.field)
71 && field_matches(val, &cond.value)
72 {
73 return false;
74 }
75 }
76 true
77}
78
79fn field_matches(val: &serde_json::Value, expected: &FieldValue) -> bool {
80 match expected {
81 FieldValue::Integer(i) => val.as_i64() == Some(*i),
82 FieldValue::Text(s) => val.as_str() == Some(s.as_str()),
83 }
84}
85
86impl VectorStore for InMemoryVectorStore {
87 fn ensure_collection(
88 &self,
89 collection: &str,
90 _vector_size: u64,
91 ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
92 let collection = collection.to_owned();
93 Box::pin(async move {
94 let mut cols = self
95 .collections
96 .write()
97 .map_err(|e| VectorStoreError::Collection(e.to_string()))?;
98 cols.entry(collection)
99 .or_insert_with(|| InMemoryCollection {
100 points: HashMap::new(),
101 });
102 Ok(())
103 })
104 }
105
106 fn collection_exists(&self, collection: &str) -> BoxFuture<'_, Result<bool, VectorStoreError>> {
107 let collection = collection.to_owned();
108 Box::pin(async move {
109 let cols = self
110 .collections
111 .read()
112 .map_err(|e| VectorStoreError::Collection(e.to_string()))?;
113 Ok(cols.contains_key(&collection))
114 })
115 }
116
117 fn delete_collection(&self, collection: &str) -> BoxFuture<'_, Result<(), VectorStoreError>> {
118 let collection = collection.to_owned();
119 Box::pin(async move {
120 let mut cols = self
121 .collections
122 .write()
123 .map_err(|e| VectorStoreError::Collection(e.to_string()))?;
124 cols.remove(&collection);
125 Ok(())
126 })
127 }
128
129 fn upsert(
130 &self,
131 collection: &str,
132 points: Vec<VectorPoint>,
133 ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
134 let collection = collection.to_owned();
135 Box::pin(async move {
136 let mut cols = self
137 .collections
138 .write()
139 .map_err(|e| VectorStoreError::Upsert(e.to_string()))?;
140 let col = cols.get_mut(&collection).ok_or_else(|| {
141 VectorStoreError::Upsert(format!("collection {collection} not found"))
142 })?;
143 for p in points {
144 col.points.insert(
145 p.id,
146 StoredPoint {
147 vector: p.vector,
148 payload: p.payload,
149 },
150 );
151 }
152 Ok(())
153 })
154 }
155
156 fn search(
157 &self,
158 collection: &str,
159 vector: Vec<f32>,
160 limit: u64,
161 filter: Option<VectorFilter>,
162 ) -> BoxFuture<'_, Result<Vec<ScoredVectorPoint>, VectorStoreError>> {
163 let collection = collection.to_owned();
164 Box::pin(async move {
165 let cols = self
166 .collections
167 .read()
168 .map_err(|e| VectorStoreError::Search(e.to_string()))?;
169 let col = cols.get(&collection).ok_or_else(|| {
170 VectorStoreError::Search(format!("collection {collection} not found"))
171 })?;
172
173 let empty_filter = VectorFilter::default();
174 let f = filter.as_ref().unwrap_or(&empty_filter);
175
176 let mut scored: Vec<ScoredVectorPoint> = col
177 .points
178 .iter()
179 .filter(|(_, sp)| matches_filter(&sp.payload, f))
180 .map(|(id, sp)| ScoredVectorPoint {
181 id: id.clone(),
182 score: cosine_similarity(&vector, &sp.vector),
183 payload: sp.payload.clone(),
184 })
185 .collect();
186
187 scored.sort_by(|a, b| {
188 b.score
189 .partial_cmp(&a.score)
190 .unwrap_or(std::cmp::Ordering::Equal)
191 });
192 #[expect(clippy::cast_possible_truncation)]
193 scored.truncate(limit as usize);
194 Ok(scored)
195 })
196 }
197
198 fn delete_by_ids(
199 &self,
200 collection: &str,
201 ids: Vec<String>,
202 ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
203 let collection = collection.to_owned();
204 Box::pin(async move {
205 if ids.is_empty() {
206 return Ok(());
207 }
208 let mut cols = self
209 .collections
210 .write()
211 .map_err(|e| VectorStoreError::Delete(e.to_string()))?;
212 let col = cols.get_mut(&collection).ok_or_else(|| {
213 VectorStoreError::Delete(format!("collection {collection} not found"))
214 })?;
215 for id in &ids {
216 col.points.remove(id);
217 }
218 Ok(())
219 })
220 }
221
222 fn scroll_all(
223 &self,
224 collection: &str,
225 key_field: &str,
226 ) -> BoxFuture<'_, Result<HashMap<String, HashMap<String, String>>, VectorStoreError>> {
227 let collection = collection.to_owned();
228 let key_field = key_field.to_owned();
229 Box::pin(async move {
230 let cols = self
231 .collections
232 .read()
233 .map_err(|e| VectorStoreError::Scroll(e.to_string()))?;
234 let col = cols.get(&collection).ok_or_else(|| {
235 VectorStoreError::Scroll(format!("collection {collection} not found"))
236 })?;
237
238 let mut result = HashMap::new();
239 for sp in col.points.values() {
240 let Some(key_val) = sp.payload.get(&key_field).and_then(|v| v.as_str()) else {
241 continue;
242 };
243 let mut fields = HashMap::new();
244 for (k, v) in &sp.payload {
245 if let Some(s) = v.as_str() {
246 fields.insert(k.clone(), s.to_owned());
247 }
248 }
249 result.insert(key_val.to_owned(), fields);
250 }
251 Ok(result)
252 })
253 }
254
255 fn health_check(&self) -> BoxFuture<'_, Result<bool, VectorStoreError>> {
256 Box::pin(async { Ok(true) })
257 }
258}
259
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a point whose payload holds a single `role` string, or is empty for `None`.
    fn point_with_role(id: &str, vector: Vec<f32>, role: Option<&str>) -> VectorPoint {
        let payload = role
            .map(|r| HashMap::from([("role".to_owned(), serde_json::json!(r))]))
            .unwrap_or_default();
        VectorPoint {
            id: id.into(),
            vector,
            payload,
        }
    }

    /// `role == <role>` condition used by the filter tests.
    fn role_is(role: &str) -> crate::vector_store::FieldCondition {
        crate::vector_store::FieldCondition {
            field: "role".into(),
            value: FieldValue::Text(role.into()),
        }
    }

    /// Store with collection "test" holding a (user), b (assistant), c (no role).
    async fn store_with_roles() -> InMemoryVectorStore {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();
        let points = vec![
            point_with_role("a", vec![1.0, 0.0, 0.0], Some("user")),
            point_with_role("b", vec![0.9, 0.1, 0.0], Some("assistant")),
            point_with_role("c", vec![0.0, 0.0, 1.0], None),
        ];
        store.upsert("test", points).await.unwrap();
        store
    }

    #[tokio::test]
    async fn ensure_collection_and_exists() {
        let store = InMemoryVectorStore::new();
        assert!(!store.collection_exists("test").await.unwrap());
        store.ensure_collection("test", 3).await.unwrap();
        assert!(store.collection_exists("test").await.unwrap());
    }

    #[tokio::test]
    async fn ensure_collection_idempotent() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();
        store.ensure_collection("test", 3).await.unwrap();
        assert!(store.collection_exists("test").await.unwrap());
    }

    #[tokio::test]
    async fn delete_collection_removes() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();
        store.delete_collection("test").await.unwrap();
        assert!(!store.collection_exists("test").await.unwrap());
    }

    #[tokio::test]
    async fn upsert_and_search() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();

        let points = vec![
            VectorPoint {
                id: "a".into(),
                vector: vec![1.0, 0.0, 0.0],
                payload: HashMap::from([("name".into(), serde_json::json!("alpha"))]),
            },
            VectorPoint {
                id: "b".into(),
                vector: vec![0.0, 1.0, 0.0],
                payload: HashMap::from([("name".into(), serde_json::json!("beta"))]),
            },
        ];
        store.upsert("test", points).await.unwrap();

        let results = store
            .search("test", vec![1.0, 0.0, 0.0], 2, None)
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "a");
        assert!((results[0].score - 1.0).abs() < f32::EPSILON);
    }

    #[tokio::test]
    async fn search_with_filter() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();

        let points = vec![
            VectorPoint {
                id: "a".into(),
                vector: vec![1.0, 0.0, 0.0],
                payload: HashMap::from([("role".into(), serde_json::json!("user"))]),
            },
            VectorPoint {
                id: "b".into(),
                vector: vec![0.9, 0.1, 0.0],
                payload: HashMap::from([("role".into(), serde_json::json!("assistant"))]),
            },
        ];
        store.upsert("test", points).await.unwrap();

        let filter = VectorFilter {
            must: vec![crate::vector_store::FieldCondition {
                field: "role".into(),
                value: FieldValue::Text("user".into()),
            }],
            must_not: vec![],
        };
        let results = store
            .search("test", vec![1.0, 0.0, 0.0], 10, Some(filter))
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "a");
    }

    #[tokio::test]
    async fn search_respects_limit_and_orders_by_score() {
        let store = store_with_roles().await;

        let results = store
            .search("test", vec![1.0, 0.0, 0.0], 2, None)
            .await
            .unwrap();
        let ids: Vec<&str> = results.iter().map(|r| r.id.as_str()).collect();
        assert_eq!(ids, ["a", "b"]);
        assert!(results[0].score >= results[1].score);

        let none = store
            .search("test", vec![1.0, 0.0, 0.0], 0, None)
            .await
            .unwrap();
        assert!(none.is_empty());
    }

    #[tokio::test]
    async fn search_must_not_excludes_matches_but_keeps_missing_field() {
        let store = store_with_roles().await;
        let filter = VectorFilter {
            must: vec![],
            must_not: vec![role_is("user")],
        };
        let results = store
            .search("test", vec![1.0, 0.0, 0.0], 10, Some(filter))
            .await
            .unwrap();
        let mut ids: Vec<&str> = results.iter().map(|r| r.id.as_str()).collect();
        ids.sort_unstable();
        assert_eq!(ids, ["b", "c"]);
    }

    #[tokio::test]
    async fn search_must_requires_field_presence() {
        let store = store_with_roles().await;
        let filter = VectorFilter {
            must: vec![role_is("assistant")],
            must_not: vec![],
        };
        let results = store
            .search("test", vec![1.0, 0.0, 0.0], 10, Some(filter))
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "b");
    }

    #[tokio::test]
    async fn search_filters_on_integer_fields() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 2).await.unwrap();
        let points = vec![
            VectorPoint {
                id: "one".into(),
                vector: vec![1.0, 0.0],
                payload: HashMap::from([("n".into(), serde_json::json!(1))]),
            },
            VectorPoint {
                id: "two".into(),
                vector: vec![1.0, 0.0],
                payload: HashMap::from([("n".into(), serde_json::json!(2))]),
            },
            VectorPoint {
                id: "text".into(),
                vector: vec![1.0, 0.0],
                payload: HashMap::from([("n".into(), serde_json::json!("1"))]),
            },
        ];
        store.upsert("test", points).await.unwrap();

        let filter = VectorFilter {
            must: vec![crate::vector_store::FieldCondition {
                field: "n".into(),
                value: FieldValue::Integer(1),
            }],
            must_not: vec![],
        };
        let results = store
            .search("test", vec![1.0, 0.0], 10, Some(filter))
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "one");
    }

    #[tokio::test]
    async fn upsert_replaces_existing_id() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();
        let first = vec![VectorPoint {
            id: "a".into(),
            vector: vec![1.0, 0.0, 0.0],
            payload: HashMap::from([("name".into(), serde_json::json!("alpha"))]),
        }];
        let second = vec![VectorPoint {
            id: "a".into(),
            vector: vec![0.0, 1.0, 0.0],
            payload: HashMap::from([("name".into(), serde_json::json!("gamma"))]),
        }];
        store.upsert("test", first).await.unwrap();
        store.upsert("test", second).await.unwrap();

        let results = store
            .search("test", vec![0.0, 1.0, 0.0], 10, None)
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(
            results[0].payload.get("name"),
            Some(&serde_json::json!("gamma"))
        );
    }

    #[tokio::test]
    async fn delete_by_ids_removes_points() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();

        let points = vec![VectorPoint {
            id: "a".into(),
            vector: vec![1.0, 0.0, 0.0],
            payload: HashMap::new(),
        }];
        store.upsert("test", points).await.unwrap();
        store.delete_by_ids("test", vec!["a".into()]).await.unwrap();

        let results = store
            .search("test", vec![1.0, 0.0, 0.0], 10, None)
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn scroll_all_extracts_strings() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();

        let points = vec![VectorPoint {
            id: "a".into(),
            vector: vec![1.0, 0.0, 0.0],
            payload: HashMap::from([
                ("name".into(), serde_json::json!("alpha")),
                ("desc".into(), serde_json::json!("first")),
                ("num".into(), serde_json::json!(42)),
            ]),
        }];
        store.upsert("test", points).await.unwrap();

        let result = store.scroll_all("test", "name").await.unwrap();
        assert_eq!(result.len(), 1);
        let fields = result.get("alpha").unwrap();
        assert_eq!(fields.get("desc").unwrap(), "first");
        assert!(!fields.contains_key("num"));
    }

    #[tokio::test]
    async fn scroll_all_skips_points_without_string_key() {
        let store = InMemoryVectorStore::new();
        store.ensure_collection("test", 3).await.unwrap();
        let points = vec![
            VectorPoint {
                id: "a".into(),
                vector: vec![1.0, 0.0, 0.0],
                payload: HashMap::from([("name".into(), serde_json::json!("alpha"))]),
            },
            VectorPoint {
                id: "b".into(),
                vector: vec![0.0, 1.0, 0.0],
                payload: HashMap::from([("desc".into(), serde_json::json!("no name"))]),
            },
            VectorPoint {
                id: "c".into(),
                vector: vec![0.0, 0.0, 1.0],
                payload: HashMap::from([("name".into(), serde_json::json!(7))]),
            },
        ];
        store.upsert("test", points).await.unwrap();

        let result = store.scroll_all("test", "name").await.unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result["alpha"].get("name").map(String::as_str), Some("alpha"));
    }

    #[tokio::test]
    async fn missing_collection_is_an_error() {
        let store = InMemoryVectorStore::new();
        assert!(store.upsert("nope", vec![]).await.is_err());
        assert!(store.search("nope", vec![1.0], 1, None).await.is_err());
        assert!(store.scroll_all("nope", "name").await.is_err());
        assert!(store.delete_by_ids("nope", vec!["a".into()]).await.is_err());
        // An empty id list short-circuits before the collection lookup.
        assert!(store.delete_by_ids("nope", vec![]).await.is_ok());
        // Removing a collection that does not exist is not an error.
        assert!(store.delete_collection("nope").await.is_ok());
    }

    #[test]
    fn cosine_similarity_orthogonal() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];
        assert!((cosine_similarity(&a, &b)).abs() < f32::EPSILON);
    }

    #[test]
    fn cosine_similarity_identical_and_zero() {
        assert!((cosine_similarity(&[3.0, 4.0], &[3.0, 4.0]) - 1.0).abs() < 1e-6);
        assert_eq!(cosine_similarity(&[0.0, 0.0], &[1.0, 1.0]), 0.0);
    }

    #[tokio::test]
    async fn default_impl() {
        let store = InMemoryVectorStore::default();
        assert!(!store.collection_exists("any").await.unwrap());
    }

    #[test]
    fn debug_format() {
        let store = InMemoryVectorStore::new();
        let dbg = format!("{store:?}");
        assert!(dbg.contains("InMemoryVectorStore"));
    }
}