1use std::collections::HashMap;
25use std::sync::Arc;
26
27use async_trait::async_trait;
28
29use crate::constants::EMBEDDING_DIMENSIONS_COUNT;
30use crate::dst::{DeterministicRng, FaultInjector};
31use crate::storage::{StorageError, StorageResult};
32
/// A single hit returned by [`VectorBackend::search`].
///
/// `PartialEq` is derived so hits can be compared directly (for example with
/// `assert_eq!`). `Eq` is intentionally absent because `score` is an `f32`;
/// two hits whose score is `NaN` never compare equal.
#[derive(Debug, Clone, PartialEq)]
pub struct VectorSearchResult {
    /// Id under which the matching embedding was stored.
    pub id: String,
    /// Similarity between the query and the stored embedding.
    pub score: f32,
}
45
46#[async_trait]
48pub trait VectorBackend: Send + Sync {
49 async fn store(&self, id: &str, embedding: &[f32]) -> StorageResult<()>;
55
56 async fn search(
65 &self,
66 embedding: &[f32],
67 limit: usize,
68 ) -> StorageResult<Vec<VectorSearchResult>>;
69
70 async fn delete(&self, id: &str) -> StorageResult<()>;
72
73 async fn exists(&self, id: &str) -> StorageResult<bool>;
75
76 async fn get(&self, id: &str) -> StorageResult<Option<Vec<f32>>>;
78
79 async fn count(&self) -> StorageResult<usize>;
81}
82
83#[derive(Clone)]
94pub struct SimVectorBackend {
95 embeddings: Arc<std::sync::RwLock<HashMap<String, Vec<f32>>>>,
97 fault_injector: Option<Arc<FaultInjector>>,
99 _rng: Arc<std::sync::RwLock<DeterministicRng>>,
101}
102
103impl SimVectorBackend {
104 #[must_use]
106 pub fn new(seed: u64) -> Self {
107 Self {
108 embeddings: Arc::new(std::sync::RwLock::new(HashMap::new())),
109 fault_injector: None,
110 _rng: Arc::new(std::sync::RwLock::new(DeterministicRng::new(seed))),
111 }
112 }
113
114 #[must_use]
116 pub fn with_faults(seed: u64, fault_injector: Arc<FaultInjector>) -> Self {
117 Self {
118 embeddings: Arc::new(std::sync::RwLock::new(HashMap::new())),
119 fault_injector: Some(fault_injector),
120 _rng: Arc::new(std::sync::RwLock::new(DeterministicRng::new(seed))),
121 }
122 }
123
124 fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
126 assert_eq!(a.len(), b.len(), "vectors must have same length");
128 assert!(!a.is_empty(), "vectors must not be empty");
129
130 let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
131 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
132 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
133
134 if norm_a == 0.0 || norm_b == 0.0 {
135 return 0.0;
136 }
137
138 let similarity = dot / (norm_a * norm_b);
139
140 (similarity + 1.0) / 2.0
142 }
143
144 fn should_inject_fault(&self, operation: &str) -> bool {
146 if let Some(ref injector) = self.fault_injector {
147 injector.should_inject(operation).is_some()
148 } else {
149 false
150 }
151 }
152}
153
154#[async_trait]
155impl VectorBackend for SimVectorBackend {
156 async fn store(&self, id: &str, embedding: &[f32]) -> StorageResult<()> {
157 assert!(!id.is_empty(), "id must not be empty");
159 assert_eq!(
160 embedding.len(),
161 EMBEDDING_DIMENSIONS_COUNT,
162 "embedding must have {} dimensions, got {}",
163 EMBEDDING_DIMENSIONS_COUNT,
164 embedding.len()
165 );
166
167 if self.should_inject_fault("vector_store_fail") {
169 return Err(StorageError::write("Injected fault: vector store failed"));
170 }
171
172 let mut embeddings = self.embeddings.write().unwrap();
173 embeddings.insert(id.to_string(), embedding.to_vec());
174
175 assert!(embeddings.contains_key(id), "embedding must be stored");
177 Ok(())
178 }
179
180 async fn search(
181 &self,
182 embedding: &[f32],
183 limit: usize,
184 ) -> StorageResult<Vec<VectorSearchResult>> {
185 assert_eq!(
187 embedding.len(),
188 EMBEDDING_DIMENSIONS_COUNT,
189 "query embedding must have {} dimensions, got {}",
190 EMBEDDING_DIMENSIONS_COUNT,
191 embedding.len()
192 );
193 assert!(limit > 0, "limit must be positive");
194
195 if self.should_inject_fault("vector_search_timeout") {
197 return Err(StorageError::timeout(5000)); }
199 if self.should_inject_fault("vector_search_fail") {
200 return Err(StorageError::read("Injected fault: vector search failed"));
201 }
202
203 let embeddings = self.embeddings.read().unwrap();
204
205 let mut results: Vec<VectorSearchResult> = embeddings
207 .iter()
208 .map(|(id, stored)| VectorSearchResult {
209 id: id.clone(),
210 score: Self::cosine_similarity(embedding, stored),
211 })
212 .collect();
213
214 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
216
217 results.truncate(limit);
219
220 assert!(results.len() <= limit, "results must not exceed limit");
222 Ok(results)
223 }
224
225 async fn delete(&self, id: &str) -> StorageResult<()> {
226 assert!(!id.is_empty(), "id must not be empty");
228
229 if self.should_inject_fault("vector_delete") {
231 return Err(StorageError::write("Injected fault: vector delete failed"));
232 }
233
234 let mut embeddings = self.embeddings.write().unwrap();
235 embeddings.remove(id);
236
237 assert!(!embeddings.contains_key(id), "embedding must be deleted");
239 Ok(())
240 }
241
242 async fn exists(&self, id: &str) -> StorageResult<bool> {
243 assert!(!id.is_empty(), "id must not be empty");
245
246 if self.should_inject_fault("vector_exists") {
248 return Err(StorageError::read(
249 "Injected fault: vector exists check failed",
250 ));
251 }
252
253 let embeddings = self.embeddings.read().unwrap();
254 Ok(embeddings.contains_key(id))
255 }
256
257 async fn get(&self, id: &str) -> StorageResult<Option<Vec<f32>>> {
258 assert!(!id.is_empty(), "id must not be empty");
260
261 if self.should_inject_fault("vector_get") {
263 return Err(StorageError::read("Injected fault: vector get failed"));
264 }
265
266 let embeddings = self.embeddings.read().unwrap();
267 Ok(embeddings.get(id).cloned())
268 }
269
270 async fn count(&self) -> StorageResult<usize> {
271 if self.should_inject_fault("vector_count") {
273 return Err(StorageError::read("Injected fault: vector count failed"));
274 }
275
276 let embeddings = self.embeddings.read().unwrap();
277 Ok(embeddings.len())
278 }
279}
280
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an embedding of `EMBEDDING_DIMENSIONS_COUNT` values drawn from a
    /// `DeterministicRng` seeded with `seed`.
    fn make_embedding(seed: u64) -> Vec<f32> {
        let mut rng = DeterministicRng::new(seed);
        let mut values = Vec::with_capacity(EMBEDDING_DIMENSIONS_COUNT);
        for _ in 0..EMBEDDING_DIMENSIONS_COUNT {
            values.push((rng.next_float() * 2.0 - 1.0) as f32);
        }
        values
    }

    /// Creates a backend holding `make_embedding(i)` under id `e{i}` for each
    /// `i` in `0..count`.
    async fn backend_with_entries(count: u64) -> SimVectorBackend {
        let backend = SimVectorBackend::new(42);
        for seed in 0..count {
            let id = format!("e{seed}");
            backend.store(&id, &make_embedding(seed)).await.unwrap();
        }
        backend
    }

    // ---------------------------------------------------------------------
    // Basic operations
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn test_sim_vector_store_and_get() {
        let backend = SimVectorBackend::new(42);
        let expected = make_embedding(1);

        backend.store("entity-1", &expected).await.unwrap();

        let fetched = backend.get("entity-1").await.unwrap();
        assert_eq!(fetched, Some(expected));
    }

    #[tokio::test]
    async fn test_sim_vector_exists() {
        let backend = SimVectorBackend::new(42);
        let embedding = make_embedding(1);

        let present_before = backend.exists("entity-1").await.unwrap();
        assert!(!present_before);

        backend.store("entity-1", &embedding).await.unwrap();

        let present_after = backend.exists("entity-1").await.unwrap();
        assert!(present_after);
    }

    #[tokio::test]
    async fn test_sim_vector_delete() {
        let backend = SimVectorBackend::new(42);
        let embedding = make_embedding(1);

        backend.store("entity-1", &embedding).await.unwrap();
        let stored = backend.exists("entity-1").await.unwrap();
        assert!(stored);

        backend.delete("entity-1").await.unwrap();
        let still_stored = backend.exists("entity-1").await.unwrap();
        assert!(!still_stored);
    }

    #[tokio::test]
    async fn test_sim_vector_count() {
        let backend = SimVectorBackend::new(42);
        assert_eq!(backend.count().await.unwrap(), 0);

        // Each store of a new id grows the count by one.
        let first = make_embedding(1);
        backend.store("e1", &first).await.unwrap();
        assert_eq!(backend.count().await.unwrap(), 1);

        let second = make_embedding(2);
        backend.store("e2", &second).await.unwrap();
        assert_eq!(backend.count().await.unwrap(), 2);

        // Deleting shrinks it again.
        backend.delete("e1").await.unwrap();
        assert_eq!(backend.count().await.unwrap(), 1);
    }

    // ---------------------------------------------------------------------
    // Search
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn test_sim_vector_search_finds_similar() {
        let backend = SimVectorBackend::new(42);

        let base = make_embedding(100);
        // A slightly perturbed copy of `base`.
        let similar = {
            let mut nudged = base.clone();
            nudged[0] += 0.01;
            nudged[1] -= 0.01;
            nudged
        };
        let different = make_embedding(999);

        for (id, embedding) in [
            ("base", &base),
            ("similar", &similar),
            ("different", &different),
        ] {
            backend.store(id, embedding).await.unwrap();
        }

        let results = backend.search(&base, 3).await.unwrap();
        assert_eq!(results.len(), 3);

        let (best, runner_up) = (&results[0], &results[1]);
        assert_eq!(best.id, "base");
        assert!((best.score - 1.0).abs() < 0.001);
        assert_eq!(runner_up.id, "similar");
        assert!(runner_up.score > 0.99);
    }

    #[tokio::test]
    async fn test_sim_vector_search_respects_limit() {
        let backend = backend_with_entries(10).await;

        let query = make_embedding(0);
        let results = backend.search(&query, 3).await.unwrap();
        assert_eq!(results.len(), 3);
    }

    #[tokio::test]
    async fn test_sim_vector_search_sorted_by_score() {
        let backend = backend_with_entries(5).await;

        let query = make_embedding(0);
        let results = backend.search(&query, 5).await.unwrap();

        // Every adjacent pair must be in non-increasing score order.
        for pair in results.windows(2) {
            assert!(
                pair[0].score >= pair[1].score,
                "results must be sorted by score descending"
            );
        }
    }

    // ---------------------------------------------------------------------
    // Cosine similarity
    // ---------------------------------------------------------------------

    #[test]
    fn test_cosine_similarity_identical() {
        let unit_x = [1.0_f32, 0.0, 0.0];
        let score = SimVectorBackend::cosine_similarity(&unit_x, &unit_x);
        assert!((score - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_cosine_similarity_opposite() {
        let unit_x = [1.0_f32, 0.0, 0.0];
        let neg_unit_x = [-1.0_f32, 0.0, 0.0];
        let score = SimVectorBackend::cosine_similarity(&unit_x, &neg_unit_x);
        assert!(score.abs() < 0.001);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let unit_x = [1.0_f32, 0.0, 0.0];
        let unit_y = [0.0_f32, 1.0, 0.0];
        let score = SimVectorBackend::cosine_similarity(&unit_x, &unit_y);
        assert!((score - 0.5).abs() < 0.001);
    }

    // ---------------------------------------------------------------------
    // Precondition violations
    // ---------------------------------------------------------------------

    #[tokio::test]
    #[should_panic(expected = "id must not be empty")]
    async fn test_sim_vector_store_empty_id() {
        let backend = SimVectorBackend::new(42);
        let embedding = make_embedding(1);
        let _ = backend.store("", &embedding).await;
    }

    #[tokio::test]
    #[should_panic(expected = "embedding must have")]
    async fn test_sim_vector_store_wrong_dimensions() {
        let backend = SimVectorBackend::new(42);
        // Three values instead of `EMBEDDING_DIMENSIONS_COUNT`.
        let too_short = vec![1.0, 2.0, 3.0];
        let _ = backend.store("entity-1", &too_short).await;
    }

    #[tokio::test]
    #[should_panic(expected = "limit must be positive")]
    async fn test_sim_vector_search_zero_limit() {
        let backend = SimVectorBackend::new(42);
        let query = make_embedding(1);
        let _ = backend.search(&query, 0).await;
    }

    // ---------------------------------------------------------------------
    // Determinism
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn test_sim_vector_deterministic() {
        /// Stores three fixed embeddings in a fresh backend and searches for the first.
        async fn run_operations(seed: u64) -> Vec<VectorSearchResult> {
            let backend = SimVectorBackend::new(seed);

            for (id, embedding_seed) in [("e1", 1), ("e2", 2), ("e3", 3)] {
                backend
                    .store(id, &make_embedding(embedding_seed))
                    .await
                    .unwrap();
            }

            backend.search(&make_embedding(1), 3).await.unwrap()
        }

        let first_run = run_operations(42).await;
        let second_run = run_operations(42).await;

        assert_eq!(first_run.len(), second_run.len());
        for (lhs, rhs) in first_run.iter().zip(&second_run) {
            assert_eq!(lhs.id, rhs.id);
            assert!((lhs.score - rhs.score).abs() < f32::EPSILON);
        }
    }
}