skill_runtime/vector_store/
in_memory.rs

1//! In-memory vector store implementation
2//!
3//! This is the default backend that stores all vectors in memory.
4//! Suitable for development, testing, and small-to-medium workloads.
5//!
6//! # Features
7//!
8//! - Zero external dependencies
9//! - Fast for small datasets (<10k documents)
10//! - Thread-safe with RwLock
11//! - Supports all filter operations
12//! - Cosine, Euclidean, and Dot Product distance metrics
13//!
14//! # Limitations
15//!
16//! - All data is lost on process restart (no persistence)
17//! - Memory usage grows linearly with documents
18//! - O(n) search complexity (no indexing)
19//! - Not suitable for >100k documents
20
21use super::{
22    cosine_similarity, euclidean_distance, DeleteStats, DistanceMetric, EmbeddedDocument, Filter,
23    HealthStatus, SearchResult, UpsertStats, VectorStore,
24};
25use anyhow::Result;
26use async_trait::async_trait;
27use std::collections::HashMap;
28use std::sync::RwLock;
29use std::time::Instant;
30
31/// In-memory vector store implementation
32///
33/// Stores documents in a HashMap protected by RwLock for thread-safety.
34/// Uses brute-force similarity search (suitable for small datasets).
35pub struct InMemoryVectorStore {
36    /// Document storage: id -> document
37    documents: RwLock<HashMap<String, EmbeddedDocument>>,
38
39    /// Distance metric to use for similarity
40    distance_metric: DistanceMetric,
41
42    /// Expected vector dimensions (for validation)
43    dimensions: Option<usize>,
44}
45
46impl InMemoryVectorStore {
47    /// Create a new in-memory vector store with default settings
48    pub fn new() -> Self {
49        Self {
50            documents: RwLock::new(HashMap::new()),
51            distance_metric: DistanceMetric::Cosine,
52            dimensions: None,
53        }
54    }
55
56    /// Create with a specific distance metric
57    pub fn with_metric(metric: DistanceMetric) -> Self {
58        Self {
59            documents: RwLock::new(HashMap::new()),
60            distance_metric: metric,
61            dimensions: None,
62        }
63    }
64
65    /// Create with expected dimensions for validation
66    pub fn with_dimensions(dimensions: usize) -> Self {
67        Self {
68            documents: RwLock::new(HashMap::new()),
69            distance_metric: DistanceMetric::Cosine,
70            dimensions: Some(dimensions),
71        }
72    }
73
74    /// Create with both metric and dimensions
75    pub fn with_config(metric: DistanceMetric, dimensions: usize) -> Self {
76        Self {
77            documents: RwLock::new(HashMap::new()),
78            distance_metric: metric,
79            dimensions: Some(dimensions),
80        }
81    }
82
83    /// Calculate similarity between two vectors based on configured metric
84    fn calculate_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
85        match self.distance_metric {
86            DistanceMetric::Cosine => cosine_similarity(a, b),
87            DistanceMetric::Euclidean => {
88                // Convert distance to similarity (0 distance = 1 similarity)
89                let dist = euclidean_distance(a, b);
90                1.0 / (1.0 + dist)
91            }
92            DistanceMetric::DotProduct => {
93                // Dot product directly (assumes normalized vectors)
94                a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
95            }
96        }
97    }
98
99    /// Validate document dimensions
100    fn validate_dimensions(&self, embedding: &[f32]) -> Result<()> {
101        if let Some(expected) = self.dimensions {
102            if embedding.len() != expected {
103                anyhow::bail!(
104                    "Embedding dimension mismatch: expected {}, got {}",
105                    expected,
106                    embedding.len()
107                );
108            }
109        }
110        Ok(())
111    }
112
113    /// Get current document count (sync version for internal use)
114    fn document_count(&self) -> usize {
115        self.documents.read().unwrap().len()
116    }
117
118    /// Clear all documents
119    pub fn clear(&self) {
120        let mut docs = self.documents.write().unwrap();
121        docs.clear();
122    }
123}
124
125impl Default for InMemoryVectorStore {
126    fn default() -> Self {
127        Self::new()
128    }
129}
130
131#[async_trait]
132impl VectorStore for InMemoryVectorStore {
133    async fn upsert(&self, documents: Vec<EmbeddedDocument>) -> Result<UpsertStats> {
134        let start = Instant::now();
135        let mut inserted = 0;
136        let mut updated = 0;
137
138        // Validate all documents first
139        for doc in &documents {
140            self.validate_dimensions(&doc.embedding)?;
141        }
142
143        // Insert/update documents
144        let mut store = self.documents.write().unwrap();
145        for doc in documents {
146            if store.contains_key(&doc.id) {
147                updated += 1;
148            } else {
149                inserted += 1;
150            }
151            store.insert(doc.id.clone(), doc);
152        }
153
154        Ok(UpsertStats::new(inserted, updated, start.elapsed().as_millis() as u64))
155    }
156
157    async fn search(
158        &self,
159        query_embedding: Vec<f32>,
160        filter: Option<Filter>,
161        top_k: usize,
162    ) -> Result<Vec<SearchResult>> {
163        self.validate_dimensions(&query_embedding)?;
164
165        let store = self.documents.read().unwrap();
166
167        // Calculate similarity for all documents
168        let mut scored: Vec<(f32, &EmbeddedDocument)> = store
169            .values()
170            .filter(|doc| {
171                // Apply metadata filter
172                filter
173                    .as_ref()
174                    .map_or(true, |f| f.matches(&doc.metadata))
175            })
176            .map(|doc| {
177                let score = self.calculate_similarity(&query_embedding, &doc.embedding);
178                (score, doc)
179            })
180            .filter(|(score, _)| {
181                // Apply min_score filter
182                filter
183                    .as_ref()
184                    .and_then(|f| f.min_score)
185                    .map_or(true, |min| *score >= min)
186            })
187            .collect();
188
189        // Sort by descending score
190        scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
191
192        // Take top_k and convert to SearchResult
193        let results: Vec<SearchResult> = scored
194            .into_iter()
195            .take(top_k)
196            .map(|(score, doc)| SearchResult::from_document(doc, score))
197            .collect();
198
199        Ok(results)
200    }
201
202    async fn delete(&self, ids: Vec<String>) -> Result<DeleteStats> {
203        let start = Instant::now();
204        let mut deleted = 0;
205        let mut not_found = 0;
206
207        let mut store = self.documents.write().unwrap();
208        for id in &ids {
209            if store.remove(id).is_some() {
210                deleted += 1;
211            } else {
212                not_found += 1;
213            }
214        }
215
216        Ok(DeleteStats::new(deleted, not_found, start.elapsed().as_millis() as u64))
217    }
218
219    async fn get(&self, ids: Vec<String>) -> Result<Vec<EmbeddedDocument>> {
220        let store = self.documents.read().unwrap();
221        let results: Vec<EmbeddedDocument> = ids
222            .iter()
223            .filter_map(|id| store.get(id).cloned())
224            .collect();
225        Ok(results)
226    }
227
228    async fn count(&self, filter: Option<Filter>) -> Result<usize> {
229        let store = self.documents.read().unwrap();
230        let count = match filter {
231            Some(f) if !f.is_empty() => store.values().filter(|doc| f.matches(&doc.metadata)).count(),
232            _ => store.len(),
233        };
234        Ok(count)
235    }
236
237    async fn health_check(&self) -> Result<HealthStatus> {
238        let start = Instant::now();
239        let count = self.document_count();
240        let latency = start.elapsed().as_millis() as u64;
241
242        Ok(HealthStatus::healthy("in_memory", latency).with_document_count(count))
243    }
244
245    fn backend_name(&self) -> &'static str {
246        "in_memory"
247    }
248
249    fn dimensions(&self) -> Option<usize> {
250        self.dimensions
251    }
252}
253
#[cfg(test)]
mod tests {
    use super::*;

    /// Four fixture documents: two "kubernetes" documents with nearly parallel
    /// embeddings, plus one "aws" and one "git" document on the other axes.
    fn sample_documents() -> Vec<EmbeddedDocument> {
        let pods = EmbeddedDocument::new("doc1", vec![1.0, 0.0, 0.0])
            .with_skill_name("kubernetes")
            .with_tool_name("get_pods")
            .with_tags(vec!["k8s".to_string()]);
        let deployment = EmbeddedDocument::new("doc2", vec![0.9, 0.1, 0.0])
            .with_skill_name("kubernetes")
            .with_tool_name("create_deployment")
            .with_tags(vec!["k8s".to_string()]);
        let buckets = EmbeddedDocument::new("doc3", vec![0.0, 1.0, 0.0])
            .with_skill_name("aws")
            .with_tool_name("list_buckets")
            .with_tags(vec!["cloud".to_string()]);
        let commit = EmbeddedDocument::new("doc4", vec![0.0, 0.0, 1.0])
            .with_skill_name("git")
            .with_tool_name("commit")
            .with_tags(vec!["vcs".to_string()]);

        vec![pods, deployment, buckets, commit]
    }

    /// A default store already loaded with the fixture documents.
    async fn populated_store() -> InMemoryVectorStore {
        let store = InMemoryVectorStore::new();
        store.upsert(sample_documents()).await.unwrap();
        store
    }

    /// Turn string literals into the owned id list the store API expects.
    fn ids(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    #[tokio::test]
    async fn test_upsert_and_count() {
        let store = InMemoryVectorStore::new();

        let stats = store.upsert(sample_documents()).await.unwrap();
        assert_eq!(stats.inserted, 4);
        assert_eq!(stats.updated, 0);
        assert_eq!(stats.total, 4);

        assert_eq!(store.count(None).await.unwrap(), 4);
    }

    #[tokio::test]
    async fn test_upsert_update() {
        let store = InMemoryVectorStore::new();

        // First write creates the document.
        let first = store
            .upsert(vec![EmbeddedDocument::new("doc1", vec![1.0, 0.0, 0.0])])
            .await
            .unwrap();
        assert_eq!(first.inserted, 1);
        assert_eq!(first.updated, 0);

        // Second write with the same id replaces it.
        let second = store
            .upsert(vec![EmbeddedDocument::new("doc1", vec![0.0, 1.0, 0.0])])
            .await
            .unwrap();
        assert_eq!(second.inserted, 0);
        assert_eq!(second.updated, 1);

        // Replacing must not grow the store.
        assert_eq!(store.count(None).await.unwrap(), 1);
    }

    #[tokio::test]
    async fn test_search_basic() {
        let store = populated_store().await;

        // Query is identical to doc1's embedding.
        let hits = store.search(vec![1.0, 0.0, 0.0], None, 2).await.unwrap();

        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].id, "doc1"); // exact match ranks first
        assert!((hits[0].score - 1.0).abs() < 1e-5); // with a perfect score
        assert_eq!(hits[1].id, "doc2"); // nearest neighbour ranks second
    }

    #[tokio::test]
    async fn test_search_with_filter() {
        let store = populated_store().await;

        // Restrict the search to the kubernetes skill.
        let only_kubernetes = Filter::new().skill("kubernetes");
        let hits = store
            .search(vec![0.5, 0.5, 0.0], Some(only_kubernetes), 10)
            .await
            .unwrap();

        assert_eq!(hits.len(), 2);
        for hit in &hits {
            assert_eq!(hit.metadata.skill_name, Some("kubernetes".to_string()));
        }
    }

    #[tokio::test]
    async fn test_search_with_tag_filter() {
        let store = populated_store().await;

        // Restrict the search to documents tagged k8s.
        let only_k8s = Filter::new().tags(vec!["k8s".to_string()]);
        let hits = store
            .search(vec![0.5, 0.5, 0.0], Some(only_k8s), 10)
            .await
            .unwrap();

        assert_eq!(hits.len(), 2);
        for hit in &hits {
            assert!(hit.metadata.tags.contains(&"k8s".to_string()));
        }
    }

    #[tokio::test]
    async fn test_search_with_min_score() {
        let store = populated_store().await;

        // A near-1.0 threshold admits only the exact match.
        let strict = Filter::new().min_score(0.9999);
        let strict_hits = store
            .search(vec![1.0, 0.0, 0.0], Some(strict), 10)
            .await
            .unwrap();

        assert_eq!(strict_hits.len(), 1);
        assert_eq!(strict_hits[0].id, "doc1");

        // A moderate threshold admits doc1 (1.0) and doc2 (0.9949...).
        let moderate = Filter::new().min_score(0.8);
        let moderate_hits = store
            .search(vec![1.0, 0.0, 0.0], Some(moderate), 10)
            .await
            .unwrap();

        assert_eq!(moderate_hits.len(), 2);
    }

    #[tokio::test]
    async fn test_delete() {
        let store = populated_store().await;

        let stats = store
            .delete(ids(&["doc1", "doc2", "nonexistent"]))
            .await
            .unwrap();

        assert_eq!(stats.deleted, 2);
        assert_eq!(stats.not_found, 1);

        assert_eq!(store.count(None).await.unwrap(), 2);
    }

    #[tokio::test]
    async fn test_get() {
        let store = populated_store().await;

        let fetched = store
            .get(ids(&["doc1", "doc3", "nonexistent"]))
            .await
            .unwrap();

        assert_eq!(fetched.len(), 2);
        assert!(fetched.iter().any(|d| d.id == "doc1"));
        assert!(fetched.iter().any(|d| d.id == "doc3"));
    }

    #[tokio::test]
    async fn test_count_with_filter() {
        let store = populated_store().await;

        let kubernetes = store
            .count(Some(Filter::new().skill("kubernetes")))
            .await
            .unwrap();
        assert_eq!(kubernetes, 2);

        let git = store.count(Some(Filter::new().skill("git"))).await.unwrap();
        assert_eq!(git, 1);
    }

    #[tokio::test]
    async fn test_health_check() {
        let store = populated_store().await;

        let status = store.health_check().await.unwrap();
        assert!(status.healthy);
        assert_eq!(status.backend, "in_memory");
        assert_eq!(status.document_count, Some(4));
    }

    #[tokio::test]
    async fn test_dimension_validation() {
        let store = InMemoryVectorStore::with_dimensions(3);

        // Three components: accepted.
        let matching = vec![EmbeddedDocument::new("doc1", vec![1.0, 0.0, 0.0])];
        assert!(store.upsert(matching).await.is_ok());

        // Two components: rejected.
        let too_short = vec![EmbeddedDocument::new("doc2", vec![1.0, 0.0])];
        assert!(store.upsert(too_short).await.is_err());
    }

    #[tokio::test]
    async fn test_euclidean_metric() {
        let store = InMemoryVectorStore::with_metric(DistanceMetric::Euclidean);
        store.upsert(sample_documents()).await.unwrap();

        let hits = store.search(vec![1.0, 0.0, 0.0], None, 2).await.unwrap();

        // The closest document is still doc1 under Euclidean distance.
        assert_eq!(hits[0].id, "doc1");
    }

    #[tokio::test]
    async fn test_clear() {
        let store = populated_store().await;
        assert_eq!(store.count(None).await.unwrap(), 4);

        store.clear();

        assert_eq!(store.count(None).await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_backend_name() {
        let store = InMemoryVectorStore::new();
        assert_eq!(store.backend_name(), "in_memory");
    }
}