skill_runtime/vector_store/
in_memory.rs1use super::{
22 cosine_similarity, euclidean_distance, DeleteStats, DistanceMetric, EmbeddedDocument, Filter,
23 HealthStatus, SearchResult, UpsertStats, VectorStore,
24};
25use anyhow::Result;
26use async_trait::async_trait;
27use std::collections::HashMap;
28use std::sync::RwLock;
29use std::time::Instant;
30
31pub struct InMemoryVectorStore {
36 documents: RwLock<HashMap<String, EmbeddedDocument>>,
38
39 distance_metric: DistanceMetric,
41
42 dimensions: Option<usize>,
44}
45
46impl InMemoryVectorStore {
47 pub fn new() -> Self {
49 Self {
50 documents: RwLock::new(HashMap::new()),
51 distance_metric: DistanceMetric::Cosine,
52 dimensions: None,
53 }
54 }
55
56 pub fn with_metric(metric: DistanceMetric) -> Self {
58 Self {
59 documents: RwLock::new(HashMap::new()),
60 distance_metric: metric,
61 dimensions: None,
62 }
63 }
64
65 pub fn with_dimensions(dimensions: usize) -> Self {
67 Self {
68 documents: RwLock::new(HashMap::new()),
69 distance_metric: DistanceMetric::Cosine,
70 dimensions: Some(dimensions),
71 }
72 }
73
74 pub fn with_config(metric: DistanceMetric, dimensions: usize) -> Self {
76 Self {
77 documents: RwLock::new(HashMap::new()),
78 distance_metric: metric,
79 dimensions: Some(dimensions),
80 }
81 }
82
83 fn calculate_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
85 match self.distance_metric {
86 DistanceMetric::Cosine => cosine_similarity(a, b),
87 DistanceMetric::Euclidean => {
88 let dist = euclidean_distance(a, b);
90 1.0 / (1.0 + dist)
91 }
92 DistanceMetric::DotProduct => {
93 a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
95 }
96 }
97 }
98
99 fn validate_dimensions(&self, embedding: &[f32]) -> Result<()> {
101 if let Some(expected) = self.dimensions {
102 if embedding.len() != expected {
103 anyhow::bail!(
104 "Embedding dimension mismatch: expected {}, got {}",
105 expected,
106 embedding.len()
107 );
108 }
109 }
110 Ok(())
111 }
112
113 fn document_count(&self) -> usize {
115 self.documents.read().unwrap().len()
116 }
117
118 pub fn clear(&self) {
120 let mut docs = self.documents.write().unwrap();
121 docs.clear();
122 }
123}
124
125impl Default for InMemoryVectorStore {
126 fn default() -> Self {
127 Self::new()
128 }
129}
130
131#[async_trait]
132impl VectorStore for InMemoryVectorStore {
133 async fn upsert(&self, documents: Vec<EmbeddedDocument>) -> Result<UpsertStats> {
134 let start = Instant::now();
135 let mut inserted = 0;
136 let mut updated = 0;
137
138 for doc in &documents {
140 self.validate_dimensions(&doc.embedding)?;
141 }
142
143 let mut store = self.documents.write().unwrap();
145 for doc in documents {
146 if store.contains_key(&doc.id) {
147 updated += 1;
148 } else {
149 inserted += 1;
150 }
151 store.insert(doc.id.clone(), doc);
152 }
153
154 Ok(UpsertStats::new(inserted, updated, start.elapsed().as_millis() as u64))
155 }
156
157 async fn search(
158 &self,
159 query_embedding: Vec<f32>,
160 filter: Option<Filter>,
161 top_k: usize,
162 ) -> Result<Vec<SearchResult>> {
163 self.validate_dimensions(&query_embedding)?;
164
165 let store = self.documents.read().unwrap();
166
167 let mut scored: Vec<(f32, &EmbeddedDocument)> = store
169 .values()
170 .filter(|doc| {
171 filter
173 .as_ref()
174 .map_or(true, |f| f.matches(&doc.metadata))
175 })
176 .map(|doc| {
177 let score = self.calculate_similarity(&query_embedding, &doc.embedding);
178 (score, doc)
179 })
180 .filter(|(score, _)| {
181 filter
183 .as_ref()
184 .and_then(|f| f.min_score)
185 .map_or(true, |min| *score >= min)
186 })
187 .collect();
188
189 scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
191
192 let results: Vec<SearchResult> = scored
194 .into_iter()
195 .take(top_k)
196 .map(|(score, doc)| SearchResult::from_document(doc, score))
197 .collect();
198
199 Ok(results)
200 }
201
202 async fn delete(&self, ids: Vec<String>) -> Result<DeleteStats> {
203 let start = Instant::now();
204 let mut deleted = 0;
205 let mut not_found = 0;
206
207 let mut store = self.documents.write().unwrap();
208 for id in &ids {
209 if store.remove(id).is_some() {
210 deleted += 1;
211 } else {
212 not_found += 1;
213 }
214 }
215
216 Ok(DeleteStats::new(deleted, not_found, start.elapsed().as_millis() as u64))
217 }
218
219 async fn get(&self, ids: Vec<String>) -> Result<Vec<EmbeddedDocument>> {
220 let store = self.documents.read().unwrap();
221 let results: Vec<EmbeddedDocument> = ids
222 .iter()
223 .filter_map(|id| store.get(id).cloned())
224 .collect();
225 Ok(results)
226 }
227
228 async fn count(&self, filter: Option<Filter>) -> Result<usize> {
229 let store = self.documents.read().unwrap();
230 let count = match filter {
231 Some(f) if !f.is_empty() => store.values().filter(|doc| f.matches(&doc.metadata)).count(),
232 _ => store.len(),
233 };
234 Ok(count)
235 }
236
237 async fn health_check(&self) -> Result<HealthStatus> {
238 let start = Instant::now();
239 let count = self.document_count();
240 let latency = start.elapsed().as_millis() as u64;
241
242 Ok(HealthStatus::healthy("in_memory", latency).with_document_count(count))
243 }
244
245 fn backend_name(&self) -> &'static str {
246 "in_memory"
247 }
248
249 fn dimensions(&self) -> Option<usize> {
250 self.dimensions
251 }
252}
253
#[cfg(test)]
mod tests {
    use super::*;

    /// Four fixture documents: two "kubernetes" documents with nearly
    /// parallel embeddings, plus one "aws" and one "git" document on the
    /// remaining two axes.
    fn create_test_documents() -> Vec<EmbeddedDocument> {
        let pods = EmbeddedDocument::new("doc1", vec![1.0, 0.0, 0.0])
            .with_skill_name("kubernetes")
            .with_tool_name("get_pods")
            .with_tags(vec!["k8s".to_string()]);
        let deployment = EmbeddedDocument::new("doc2", vec![0.9, 0.1, 0.0])
            .with_skill_name("kubernetes")
            .with_tool_name("create_deployment")
            .with_tags(vec!["k8s".to_string()]);
        let buckets = EmbeddedDocument::new("doc3", vec![0.0, 1.0, 0.0])
            .with_skill_name("aws")
            .with_tool_name("list_buckets")
            .with_tags(vec!["cloud".to_string()]);
        let commit = EmbeddedDocument::new("doc4", vec![0.0, 0.0, 1.0])
            .with_skill_name("git")
            .with_tool_name("commit")
            .with_tags(vec!["vcs".to_string()]);

        vec![pods, deployment, buckets, commit]
    }

    /// A default (cosine) store pre-loaded with the fixture documents.
    async fn seeded_store() -> InMemoryVectorStore {
        let store = InMemoryVectorStore::new();
        store.upsert(create_test_documents()).await.unwrap();
        store
    }

    /// Converts string literals into the owned ids the store API expects.
    fn ids(raw: &[&str]) -> Vec<String> {
        raw.iter().map(|id| id.to_string()).collect()
    }

    #[tokio::test]
    async fn test_upsert_and_count() {
        let store = InMemoryVectorStore::new();

        let stats = store.upsert(create_test_documents()).await.unwrap();
        assert_eq!(stats.inserted, 4);
        assert_eq!(stats.updated, 0);
        assert_eq!(stats.total, 4);

        assert_eq!(store.count(None).await.unwrap(), 4);
    }

    #[tokio::test]
    async fn test_upsert_update() {
        let store = InMemoryVectorStore::new();

        // First write of the id is an insert.
        let first = store
            .upsert(vec![EmbeddedDocument::new("doc1", vec![1.0, 0.0, 0.0])])
            .await
            .unwrap();
        assert_eq!(first.inserted, 1);
        assert_eq!(first.updated, 0);

        // Second write of the same id is an update.
        let second = store
            .upsert(vec![EmbeddedDocument::new("doc1", vec![0.0, 1.0, 0.0])])
            .await
            .unwrap();
        assert_eq!(second.inserted, 0);
        assert_eq!(second.updated, 1);

        // The update replaced the document rather than adding another.
        assert_eq!(store.count(None).await.unwrap(), 1);
    }

    #[tokio::test]
    async fn test_search_basic() {
        let store = seeded_store().await;

        let hits = store.search(vec![1.0, 0.0, 0.0], None, 2).await.unwrap();

        assert_eq!(hits.len(), 2);
        // Exact match first, with a cosine similarity of 1.
        assert_eq!(hits[0].id, "doc1");
        assert!((hits[0].score - 1.0).abs() < 1e-5);
        // Nearly parallel document second.
        assert_eq!(hits[1].id, "doc2");
    }

    #[tokio::test]
    async fn test_search_with_filter() {
        let store = seeded_store().await;

        let only_k8s_skill = Filter::new().skill("kubernetes");
        let hits = store
            .search(vec![0.5, 0.5, 0.0], Some(only_k8s_skill), 10)
            .await
            .unwrap();

        assert_eq!(hits.len(), 2);
        for hit in &hits {
            assert_eq!(hit.metadata.skill_name, Some("kubernetes".to_string()));
        }
    }

    #[tokio::test]
    async fn test_search_with_tag_filter() {
        let store = seeded_store().await;

        let only_k8s_tag = Filter::new().tags(vec!["k8s".to_string()]);
        let hits = store
            .search(vec![0.5, 0.5, 0.0], Some(only_k8s_tag), 10)
            .await
            .unwrap();

        assert_eq!(hits.len(), 2);
        for hit in &hits {
            assert!(hit.metadata.tags.contains(&"k8s".to_string()));
        }
    }

    #[tokio::test]
    async fn test_search_with_min_score() {
        let store = seeded_store().await;

        // A very high threshold keeps only the exact match.
        let strict = store
            .search(vec![1.0, 0.0, 0.0], Some(Filter::new().min_score(0.9999)), 10)
            .await
            .unwrap();
        assert_eq!(strict.len(), 1);
        assert_eq!(strict[0].id, "doc1");

        // A looser threshold also admits the nearly parallel document.
        let loose = store
            .search(vec![1.0, 0.0, 0.0], Some(Filter::new().min_score(0.8)), 10)
            .await
            .unwrap();
        assert_eq!(loose.len(), 2);
    }

    #[tokio::test]
    async fn test_delete() {
        let store = seeded_store().await;

        let stats = store
            .delete(ids(&["doc1", "doc2", "nonexistent"]))
            .await
            .unwrap();

        assert_eq!(stats.deleted, 2);
        assert_eq!(stats.not_found, 1);
        assert_eq!(store.count(None).await.unwrap(), 2);
    }

    #[tokio::test]
    async fn test_get() {
        let store = seeded_store().await;

        let fetched = store
            .get(ids(&["doc1", "doc3", "nonexistent"]))
            .await
            .unwrap();

        assert_eq!(fetched.len(), 2);
        assert!(fetched.iter().any(|d| d.id == "doc1"));
        assert!(fetched.iter().any(|d| d.id == "doc3"));
    }

    #[tokio::test]
    async fn test_count_with_filter() {
        let store = seeded_store().await;

        let kubernetes = store
            .count(Some(Filter::new().skill("kubernetes")))
            .await
            .unwrap();
        assert_eq!(kubernetes, 2);

        let git = store.count(Some(Filter::new().skill("git"))).await.unwrap();
        assert_eq!(git, 1);
    }

    #[tokio::test]
    async fn test_health_check() {
        let store = seeded_store().await;

        let status = store.health_check().await.unwrap();

        assert!(status.healthy);
        assert_eq!(status.backend, "in_memory");
        assert_eq!(status.document_count, Some(4));
    }

    #[tokio::test]
    async fn test_dimension_validation() {
        let store = InMemoryVectorStore::with_dimensions(3);

        // Matching length is accepted.
        let matching = vec![EmbeddedDocument::new("doc1", vec![1.0, 0.0, 0.0])];
        assert!(store.upsert(matching).await.is_ok());

        // A two-component embedding is rejected by a three-dimensional store.
        let too_short = vec![EmbeddedDocument::new("doc2", vec![1.0, 0.0])];
        assert!(store.upsert(too_short).await.is_err());
    }

    #[tokio::test]
    async fn test_euclidean_metric() {
        let store = InMemoryVectorStore::with_metric(DistanceMetric::Euclidean);
        store.upsert(create_test_documents()).await.unwrap();

        let hits = store.search(vec![1.0, 0.0, 0.0], None, 2).await.unwrap();

        // Zero distance must still rank first under the Euclidean metric.
        assert_eq!(hits[0].id, "doc1");
    }

    #[tokio::test]
    async fn test_clear() {
        let store = seeded_store().await;
        assert_eq!(store.count(None).await.unwrap(), 4);

        store.clear();

        assert_eq!(store.count(None).await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_backend_name() {
        let store = InMemoryVectorStore::new();
        assert_eq!(store.backend_name(), "in_memory");
    }
}