umi_memory/storage/
vector.rs

1//! Vector Backend - Embedding Storage and Similarity Search (ADR-006)
2//!
3//! TigerStyle: Trait-based abstraction, simulation-first testing.
4//!
5//! # Overview
6//!
7//! Stores vector embeddings and enables similarity search.
8//! Used for semantic search in dual retrieval strategy.
9//!
10//! # Architecture
11//!
12//! ```text
13//! ┌─────────────────────────────────────────────────────────────┐
14//! │                    VectorBackend Trait                       │
15//! └─────────────────────────────────────────────────────────────┘
16//!          ↑                              ↑
17//!          │                              │
18//! ┌────────┴────────┐           ┌────────┴────────┐
19//! │SimVectorBackend │           │ QdrantBackend   │
20//! │   (testing)     │           │  (production)   │
21//! └─────────────────┘           └─────────────────┘
22//! ```
23
24use std::collections::HashMap;
25use std::sync::Arc;
26
27use async_trait::async_trait;
28
29use crate::constants::EMBEDDING_DIMENSIONS_COUNT;
30use crate::dst::{DeterministicRng, FaultInjector};
31use crate::storage::{StorageError, StorageResult};
32
33// =============================================================================
34// Vector Backend Trait
35// =============================================================================
36
/// A single hit produced by a similarity search.
///
/// Pairs the ID of a stored embedding with how closely it matched the query.
#[derive(Debug, Clone)]
pub struct VectorSearchResult {
    /// ID of the entity whose embedding matched.
    pub id: String,
    /// Similarity to the query on a 0.0–1.0 scale; larger means more similar.
    pub score: f32,
}
45
46/// Trait for vector embedding storage backends.
47#[async_trait]
48pub trait VectorBackend: Send + Sync {
49    /// Store an embedding for an entity.
50    ///
51    /// # Arguments
52    /// * `id` - Entity ID to associate with the embedding
53    /// * `embedding` - Vector embedding (must match EMBEDDING_DIMENSIONS_COUNT)
54    async fn store(&self, id: &str, embedding: &[f32]) -> StorageResult<()>;
55
56    /// Search for similar embeddings.
57    ///
58    /// # Arguments
59    /// * `embedding` - Query embedding
60    /// * `limit` - Maximum number of results
61    ///
62    /// # Returns
63    /// Results sorted by similarity (highest first)
64    async fn search(
65        &self,
66        embedding: &[f32],
67        limit: usize,
68    ) -> StorageResult<Vec<VectorSearchResult>>;
69
70    /// Delete an embedding.
71    async fn delete(&self, id: &str) -> StorageResult<()>;
72
73    /// Check if an embedding exists.
74    async fn exists(&self, id: &str) -> StorageResult<bool>;
75
76    /// Get the embedding for an entity.
77    async fn get(&self, id: &str) -> StorageResult<Option<Vec<f32>>>;
78
79    /// Get the number of stored embeddings.
80    async fn count(&self) -> StorageResult<usize>;
81}
82
83// =============================================================================
84// Simulated Vector Backend (for DST)
85// =============================================================================
86
87/// In-memory vector backend for deterministic simulation testing.
88///
89/// Features:
90/// - Deterministic similarity computation
91/// - Fault injection support
92/// - No external dependencies
93#[derive(Clone)]
94pub struct SimVectorBackend {
95    /// Stored embeddings
96    embeddings: Arc<std::sync::RwLock<HashMap<String, Vec<f32>>>>,
97    /// Fault injector for testing error paths
98    fault_injector: Option<Arc<FaultInjector>>,
99    /// RNG for deterministic behavior
100    _rng: Arc<std::sync::RwLock<DeterministicRng>>,
101}
102
103impl SimVectorBackend {
104    /// Create a new simulated vector backend.
105    #[must_use]
106    pub fn new(seed: u64) -> Self {
107        Self {
108            embeddings: Arc::new(std::sync::RwLock::new(HashMap::new())),
109            fault_injector: None,
110            _rng: Arc::new(std::sync::RwLock::new(DeterministicRng::new(seed))),
111        }
112    }
113
114    /// Create with fault injection enabled.
115    #[must_use]
116    pub fn with_faults(seed: u64, fault_injector: Arc<FaultInjector>) -> Self {
117        Self {
118            embeddings: Arc::new(std::sync::RwLock::new(HashMap::new())),
119            fault_injector: Some(fault_injector),
120            _rng: Arc::new(std::sync::RwLock::new(DeterministicRng::new(seed))),
121        }
122    }
123
124    /// Compute cosine similarity between two vectors.
125    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
126        // Preconditions
127        assert_eq!(a.len(), b.len(), "vectors must have same length");
128        assert!(!a.is_empty(), "vectors must not be empty");
129
130        let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
131        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
132        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
133
134        if norm_a == 0.0 || norm_b == 0.0 {
135            return 0.0;
136        }
137
138        let similarity = dot / (norm_a * norm_b);
139
140        // Postcondition: similarity is in [-1, 1], normalize to [0, 1]
141        (similarity + 1.0) / 2.0
142    }
143
144    /// Check if a fault should be injected.
145    fn should_inject_fault(&self, operation: &str) -> bool {
146        if let Some(ref injector) = self.fault_injector {
147            injector.should_inject(operation).is_some()
148        } else {
149            false
150        }
151    }
152}
153
154#[async_trait]
155impl VectorBackend for SimVectorBackend {
156    async fn store(&self, id: &str, embedding: &[f32]) -> StorageResult<()> {
157        // Preconditions
158        assert!(!id.is_empty(), "id must not be empty");
159        assert_eq!(
160            embedding.len(),
161            EMBEDDING_DIMENSIONS_COUNT,
162            "embedding must have {} dimensions, got {}",
163            EMBEDDING_DIMENSIONS_COUNT,
164            embedding.len()
165        );
166
167        // Fault injection
168        if self.should_inject_fault("vector_store_fail") {
169            return Err(StorageError::write("Injected fault: vector store failed"));
170        }
171
172        let mut embeddings = self.embeddings.write().unwrap();
173        embeddings.insert(id.to_string(), embedding.to_vec());
174
175        // Postcondition
176        assert!(embeddings.contains_key(id), "embedding must be stored");
177        Ok(())
178    }
179
180    async fn search(
181        &self,
182        embedding: &[f32],
183        limit: usize,
184    ) -> StorageResult<Vec<VectorSearchResult>> {
185        // Preconditions
186        assert_eq!(
187            embedding.len(),
188            EMBEDDING_DIMENSIONS_COUNT,
189            "query embedding must have {} dimensions, got {}",
190            EMBEDDING_DIMENSIONS_COUNT,
191            embedding.len()
192        );
193        assert!(limit > 0, "limit must be positive");
194
195        // Fault injection
196        if self.should_inject_fault("vector_search_timeout") {
197            return Err(StorageError::timeout(5000)); // 5 second timeout
198        }
199        if self.should_inject_fault("vector_search_fail") {
200            return Err(StorageError::read("Injected fault: vector search failed"));
201        }
202
203        let embeddings = self.embeddings.read().unwrap();
204
205        // Compute similarities
206        let mut results: Vec<VectorSearchResult> = embeddings
207            .iter()
208            .map(|(id, stored)| VectorSearchResult {
209                id: id.clone(),
210                score: Self::cosine_similarity(embedding, stored),
211            })
212            .collect();
213
214        // Sort by score descending
215        results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
216
217        // Limit results
218        results.truncate(limit);
219
220        // Postcondition
221        assert!(results.len() <= limit, "results must not exceed limit");
222        Ok(results)
223    }
224
225    async fn delete(&self, id: &str) -> StorageResult<()> {
226        // Precondition
227        assert!(!id.is_empty(), "id must not be empty");
228
229        // Fault injection
230        if self.should_inject_fault("vector_delete") {
231            return Err(StorageError::write("Injected fault: vector delete failed"));
232        }
233
234        let mut embeddings = self.embeddings.write().unwrap();
235        embeddings.remove(id);
236
237        // Postcondition
238        assert!(!embeddings.contains_key(id), "embedding must be deleted");
239        Ok(())
240    }
241
242    async fn exists(&self, id: &str) -> StorageResult<bool> {
243        // Precondition
244        assert!(!id.is_empty(), "id must not be empty");
245
246        // Fault injection
247        if self.should_inject_fault("vector_exists") {
248            return Err(StorageError::read(
249                "Injected fault: vector exists check failed",
250            ));
251        }
252
253        let embeddings = self.embeddings.read().unwrap();
254        Ok(embeddings.contains_key(id))
255    }
256
257    async fn get(&self, id: &str) -> StorageResult<Option<Vec<f32>>> {
258        // Precondition
259        assert!(!id.is_empty(), "id must not be empty");
260
261        // Fault injection
262        if self.should_inject_fault("vector_get") {
263            return Err(StorageError::read("Injected fault: vector get failed"));
264        }
265
266        let embeddings = self.embeddings.read().unwrap();
267        Ok(embeddings.get(id).cloned())
268    }
269
270    async fn count(&self) -> StorageResult<usize> {
271        // Fault injection
272        if self.should_inject_fault("vector_count") {
273            return Err(StorageError::read("Injected fault: vector count failed"));
274        }
275
276        let embeddings = self.embeddings.read().unwrap();
277        Ok(embeddings.len())
278    }
279}
280
281// =============================================================================
282// Tests
283// =============================================================================
284
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a reproducible embedding of the configured width from `seed`.
    /// Equal seeds always give equal vectors. Values in [-1, 1].
    fn make_embedding(seed: u64) -> Vec<f32> {
        let mut rng = DeterministicRng::new(seed);
        let mut values = Vec::with_capacity(EMBEDDING_DIMENSIONS_COUNT);
        for _ in 0..EMBEDDING_DIMENSIONS_COUNT {
            values.push((rng.next_float() * 2.0 - 1.0) as f32);
        }
        values
    }

    // =========================================================================
    // SimVectorBackend Tests
    // =========================================================================

    #[tokio::test]
    async fn test_sim_vector_store_and_get() {
        let backend = SimVectorBackend::new(42);
        let original = make_embedding(1);

        backend.store("entity-1", &original).await.unwrap();

        // Round-trip: what we read back must equal what we wrote.
        let fetched = backend.get("entity-1").await.unwrap();
        assert_eq!(fetched.as_deref(), Some(original.as_slice()));
    }

    #[tokio::test]
    async fn test_sim_vector_exists() {
        let backend = SimVectorBackend::new(42);

        // Absent before the first write, present afterwards.
        assert!(!backend.exists("entity-1").await.unwrap());
        backend.store("entity-1", &make_embedding(1)).await.unwrap();
        assert!(backend.exists("entity-1").await.unwrap());
    }

    #[tokio::test]
    async fn test_sim_vector_delete() {
        let backend = SimVectorBackend::new(42);

        backend.store("entity-1", &make_embedding(1)).await.unwrap();
        assert!(backend.exists("entity-1").await.unwrap());

        backend.delete("entity-1").await.unwrap();
        assert!(!backend.exists("entity-1").await.unwrap());
    }

    #[tokio::test]
    async fn test_sim_vector_count() {
        let backend = SimVectorBackend::new(42);
        assert_eq!(backend.count().await.unwrap(), 0);

        // Each insert grows the count by one.
        for (seed, id) in [(1, "e1"), (2, "e2")] {
            backend.store(id, &make_embedding(seed)).await.unwrap();
            assert_eq!(backend.count().await.unwrap(), seed as usize);
        }

        backend.delete("e1").await.unwrap();
        assert_eq!(backend.count().await.unwrap(), 1);
    }

    #[tokio::test]
    async fn test_sim_vector_search_finds_similar() {
        let backend = SimVectorBackend::new(42);

        let anchor = make_embedding(100);
        let mut nudged = anchor.clone();
        nudged[0] += 0.01;
        nudged[1] -= 0.01;
        let unrelated = make_embedding(999);

        for (id, vector) in [("base", &anchor), ("similar", &nudged), ("different", &unrelated)] {
            backend.store(id, vector).await.unwrap();
        }

        let hits = backend.search(&anchor, 3).await.unwrap();
        assert_eq!(hits.len(), 3);

        // The query itself is an exact match and must lead with score 1.0.
        assert_eq!(hits[0].id, "base");
        assert!((hits[0].score - 1.0).abs() < 0.001);

        // The lightly perturbed copy ranks right behind it.
        assert_eq!(hits[1].id, "similar");
        assert!(hits[1].score > 0.99);
    }

    #[tokio::test]
    async fn test_sim_vector_search_respects_limit() {
        let backend = SimVectorBackend::new(42);

        for seed in 0..10 {
            let id = format!("e{seed}");
            backend.store(&id, &make_embedding(seed)).await.unwrap();
        }

        let hits = backend.search(&make_embedding(0), 3).await.unwrap();
        assert_eq!(hits.len(), 3);
    }

    #[tokio::test]
    async fn test_sim_vector_search_sorted_by_score() {
        let backend = SimVectorBackend::new(42);

        for seed in 0..5 {
            let id = format!("e{seed}");
            backend.store(&id, &make_embedding(seed)).await.unwrap();
        }

        let hits = backend.search(&make_embedding(0), 5).await.unwrap();

        // Every adjacent pair must be in non-increasing score order.
        assert!(
            hits.windows(2).all(|pair| pair[0].score >= pair[1].score),
            "results must be sorted by score descending"
        );
    }

    // =========================================================================
    // Cosine Similarity Tests
    // =========================================================================

    #[test]
    fn test_cosine_similarity_identical() {
        let unit_x: [f32; 3] = [1.0, 0.0, 0.0];
        let score = SimVectorBackend::cosine_similarity(&unit_x, &unit_x);
        // Same direction maps to the top of the [0, 1] scale.
        assert!((score - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_cosine_similarity_opposite() {
        let unit_x: [f32; 3] = [1.0, 0.0, 0.0];
        let neg_x: [f32; 3] = [-1.0, 0.0, 0.0];
        let score = SimVectorBackend::cosine_similarity(&unit_x, &neg_x);
        // Opposite direction maps to the bottom of the [0, 1] scale.
        assert!(score.abs() < 0.001);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let unit_x: [f32; 3] = [1.0, 0.0, 0.0];
        let unit_y: [f32; 3] = [0.0, 1.0, 0.0];
        let score = SimVectorBackend::cosine_similarity(&unit_x, &unit_y);
        // Orthogonal vectors land in the middle of the [0, 1] scale.
        assert!((score - 0.5).abs() < 0.001);
    }

    // =========================================================================
    // Precondition Tests
    // =========================================================================

    #[tokio::test]
    #[should_panic(expected = "id must not be empty")]
    async fn test_sim_vector_store_empty_id() {
        let backend = SimVectorBackend::new(42);
        let _ = backend.store("", &make_embedding(1)).await;
    }

    #[tokio::test]
    #[should_panic(expected = "embedding must have")]
    async fn test_sim_vector_store_wrong_dimensions() {
        let backend = SimVectorBackend::new(42);
        let too_short = [1.0_f32, 2.0, 3.0];
        let _ = backend.store("entity-1", &too_short).await;
    }

    #[tokio::test]
    #[should_panic(expected = "limit must be positive")]
    async fn test_sim_vector_search_zero_limit() {
        let backend = SimVectorBackend::new(42);
        let _ = backend.search(&make_embedding(1), 0).await;
    }

    // =========================================================================
    // Determinism Tests (DST)
    // =========================================================================

    #[tokio::test]
    async fn test_sim_vector_deterministic() {
        /// Store three fixed embeddings in a fresh backend and query with the first.
        async fn scenario(seed: u64) -> Vec<VectorSearchResult> {
            let backend = SimVectorBackend::new(seed);
            for n in 1..=3 {
                let id = format!("e{n}");
                backend.store(&id, &make_embedding(n)).await.unwrap();
            }
            backend.search(&make_embedding(1), 3).await.unwrap()
        }

        // Replaying the same seed must reproduce the same ranking and scores.
        let first = scenario(42).await;
        let second = scenario(42).await;

        assert_eq!(first.len(), second.len());
        for (a, b) in first.iter().zip(&second) {
            assert_eq!(a.id, b.id);
            assert!((a.score - b.score).abs() < f32::EPSILON);
        }
    }
}