umi_memory/embedding/
sim.rs1use std::collections::hash_map::DefaultHasher;
37use std::hash::{Hash, Hasher};
38
39use std::sync::Arc;
40
41use async_trait::async_trait;
42
43use super::{EmbeddingError, EmbeddingProvider};
44use crate::constants::EMBEDDING_DIMENSIONS_COUNT;
45use crate::dst::{DeterministicRng, FaultInjector};
46
47#[derive(Clone)]
60pub struct SimEmbeddingProvider {
61 seed: u64,
63 dimensions: usize,
65 fault_injector: Option<Arc<FaultInjector>>,
67}
68
69impl SimEmbeddingProvider {
70 #[must_use]
83 pub fn new(seed: u64) -> Self {
84 Self {
85 seed,
86 dimensions: EMBEDDING_DIMENSIONS_COUNT,
87 fault_injector: None,
88 }
89 }
90
91 #[must_use]
93 pub fn with_seed(seed: u64) -> Self {
94 Self::new(seed)
95 }
96
97 #[must_use]
99 pub fn with_faults(seed: u64, fault_injector: Arc<FaultInjector>) -> Self {
100 Self {
101 seed,
102 dimensions: EMBEDDING_DIMENSIONS_COUNT,
103 fault_injector: Some(fault_injector),
104 }
105 }
106
107 fn should_inject_fault(&self, operation: &str) -> bool {
109 if let Some(ref injector) = self.fault_injector {
110 injector.should_inject(operation).is_some()
111 } else {
112 false
113 }
114 }
115
116 fn hash_text(&self, text: &str) -> u64 {
120 let mut hasher = DefaultHasher::new();
121 self.seed.hash(&mut hasher);
122 text.hash(&mut hasher);
123 hasher.finish()
124 }
125
126 fn generate_embedding(&self, text: &str) -> Vec<f32> {
133 let text_seed = self.hash_text(text);
135 let mut rng = DeterministicRng::new(text_seed);
136
137 let mut embedding: Vec<f32> = (0..self.dimensions)
139 .map(|_| {
140 let val = rng.next_float();
141 (val * 2.0 - 1.0) as f32 })
143 .collect();
144
145 let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
147 if norm > 0.0 {
148 for val in &mut embedding {
149 *val /= norm;
150 }
151 }
152
153 debug_assert!(
155 {
156 let check_norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
157 (check_norm - 1.0).abs() < 0.001
158 },
159 "embedding must be normalized to unit vector"
160 );
161 debug_assert_eq!(
162 embedding.len(),
163 self.dimensions,
164 "embedding must have correct dimensions"
165 );
166
167 embedding
168 }
169}
170
171#[async_trait]
172impl EmbeddingProvider for SimEmbeddingProvider {
173 async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
174 if text.is_empty() {
176 return Err(EmbeddingError::EmptyInput);
177 }
178
179 if self.should_inject_fault("embedding_timeout") {
181 return Err(EmbeddingError::Timeout);
182 }
183 if self.should_inject_fault("embedding_rate_limit") {
184 return Err(EmbeddingError::rate_limit(Some(60)));
185 }
186 if self.should_inject_fault("embedding_service_unavailable") {
187 return Err(EmbeddingError::service_unavailable("Simulated failure"));
188 }
189
190 Ok(self.generate_embedding(text))
191 }
192
193 async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
194 if texts.is_empty() {
196 return Err(EmbeddingError::invalid_request("batch cannot be empty"));
197 }
198
199 if self.should_inject_fault("embedding_timeout") {
201 return Err(EmbeddingError::Timeout);
202 }
203 if self.should_inject_fault("embedding_rate_limit") {
204 return Err(EmbeddingError::rate_limit(Some(60)));
205 }
206 if self.should_inject_fault("embedding_service_unavailable") {
207 return Err(EmbeddingError::service_unavailable("Simulated failure"));
208 }
209
210 let mut embeddings = Vec::with_capacity(texts.len());
212 for text in texts {
213 if text.is_empty() {
214 return Err(EmbeddingError::EmptyInput);
215 }
216 embeddings.push(self.generate_embedding(text));
217 }
218
219 debug_assert_eq!(
221 embeddings.len(),
222 texts.len(),
223 "must return one embedding per input"
224 );
225
226 Ok(embeddings)
227 }
228
229 fn dimensions(&self) -> usize {
230 self.dimensions
231 }
232
233 fn name(&self) -> &'static str {
234 "sim-embedding"
235 }
236
237 fn is_simulation(&self) -> bool {
238 true
239 }
240}
241
#[cfg(test)]
mod tests {
    use super::*;

    /// Seed used by most tests.
    const SEED: u64 = 42;
    /// First sample sentence.
    const ALICE: &str = "Alice works at Acme";
    /// Second sample sentence, different from `ALICE`.
    const BOB: &str = "Bob works at TechCo";

    /// Euclidean (L2) norm of `values`.
    fn l2_norm(values: &[f32]) -> f32 {
        values.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    // --- Single-text embedding -------------------------------------------

    #[tokio::test]
    async fn test_sim_embedding_basic() {
        let provider = SimEmbeddingProvider::new(SEED);

        let embedding = provider.embed(ALICE).await.unwrap();

        assert_eq!(embedding.len(), EMBEDDING_DIMENSIONS_COUNT);
    }

    #[tokio::test]
    async fn test_sim_embedding_deterministic() {
        let provider = SimEmbeddingProvider::new(SEED);

        let first = provider.embed(ALICE).await.unwrap();
        let second = provider.embed(ALICE).await.unwrap();

        assert_eq!(first, second);
    }

    #[tokio::test]
    async fn test_sim_embedding_different_text() {
        let provider = SimEmbeddingProvider::new(SEED);

        let alice = provider.embed(ALICE).await.unwrap();
        let bob = provider.embed(BOB).await.unwrap();

        assert_ne!(alice, bob);
    }

    #[tokio::test]
    async fn test_sim_embedding_different_seed() {
        let seeded_42 = SimEmbeddingProvider::new(42);
        let seeded_99 = SimEmbeddingProvider::new(99);

        let from_42 = seeded_42.embed(ALICE).await.unwrap();
        let from_99 = seeded_99.embed(ALICE).await.unwrap();

        assert_ne!(from_42, from_99);
    }

    #[tokio::test]
    async fn test_sim_embedding_normalized() {
        let provider = SimEmbeddingProvider::new(SEED);

        let embedding = provider.embed(ALICE).await.unwrap();
        let norm = l2_norm(&embedding);

        assert!((norm - 1.0).abs() < 0.001, "embedding must be normalized");
    }

    #[tokio::test]
    async fn test_sim_embedding_empty_text() {
        let provider = SimEmbeddingProvider::new(SEED);

        let result = provider.embed("").await;

        assert!(matches!(result, Err(EmbeddingError::EmptyInput)));
    }

    // --- Batch embedding -------------------------------------------------

    #[tokio::test]
    async fn test_sim_embedding_batch() {
        let provider = SimEmbeddingProvider::new(SEED);
        let texts = [ALICE, BOB];

        let embeddings = provider.embed_batch(&texts).await.unwrap();

        assert_eq!(embeddings.len(), 2);
        assert_eq!(embeddings[0].len(), EMBEDDING_DIMENSIONS_COUNT);
        assert_eq!(embeddings[1].len(), EMBEDDING_DIMENSIONS_COUNT);

        // Each batch entry must equal what a single `embed` call returns.
        let single1 = provider.embed(texts[0]).await.unwrap();
        let single2 = provider.embed(texts[1]).await.unwrap();

        assert_eq!(embeddings[0], single1);
        assert_eq!(embeddings[1], single2);
    }

    #[tokio::test]
    async fn test_sim_embedding_batch_empty() {
        let provider = SimEmbeddingProvider::new(SEED);
        let texts: [&str; 0] = [];

        let result = provider.embed_batch(&texts).await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_sim_embedding_batch_with_empty_text() {
        let provider = SimEmbeddingProvider::new(SEED);
        let texts = ["Alice", ""];

        let result = provider.embed_batch(&texts).await;

        assert!(matches!(result, Err(EmbeddingError::EmptyInput)));
    }

    // --- Trait metadata --------------------------------------------------

    #[tokio::test]
    async fn test_sim_embedding_provider_traits() {
        let provider = SimEmbeddingProvider::new(SEED);

        assert_eq!(provider.dimensions(), EMBEDDING_DIMENSIONS_COUNT);
        assert_eq!(provider.name(), "sim-embedding");
        assert!(provider.is_simulation());
    }

    // --- Determinism across provider instances ---------------------------

    #[tokio::test]
    async fn test_determinism_same_seed_same_results() {
        // A fresh provider is built for every run so nothing is shared.
        async fn run_with_seed(seed: u64) -> Vec<f32> {
            SimEmbeddingProvider::new(seed)
                .embed("test text")
                .await
                .unwrap()
        }

        let result1 = run_with_seed(42).await;
        let result2 = run_with_seed(42).await;

        assert_eq!(result1, result2, "same seed must produce same results");
    }

    #[tokio::test]
    async fn test_determinism_different_seed_different_results() {
        let seeded_42 = SimEmbeddingProvider::new(42);
        let seeded_43 = SimEmbeddingProvider::new(43);

        let result1 = seeded_42.embed("test text").await.unwrap();
        let result2 = seeded_43.embed("test text").await.unwrap();

        assert_ne!(
            result1, result2,
            "different seeds must produce different results"
        );
    }

    #[tokio::test]
    async fn test_batch_determinism() {
        let provider = SimEmbeddingProvider::new(SEED);
        let texts = ["text1", "text2", "text3"];

        let batch1 = provider.embed_batch(&texts).await.unwrap();
        let batch2 = provider.embed_batch(&texts).await.unwrap();

        assert_eq!(batch1, batch2, "batch must be deterministic");
    }

    // --- Normalization across input lengths ------------------------------

    #[tokio::test]
    async fn test_all_embeddings_normalized() {
        let provider = SimEmbeddingProvider::new(SEED);
        let texts = [
            "short",
            "longer text here",
            "even longer text with more words to test different lengths",
        ];

        for text in texts {
            let embedding = provider.embed(text).await.unwrap();
            let norm = l2_norm(&embedding);
            assert!(
                (norm - 1.0).abs() < 0.001,
                "embedding for '{}' must be normalized, got norm {}",
                text,
                norm
            );
        }
    }

    // --- hash_text -------------------------------------------------------

    #[test]
    fn test_hash_text_deterministic() {
        let provider = SimEmbeddingProvider::new(SEED);

        let hash1 = provider.hash_text("test");
        let hash2 = provider.hash_text("test");

        assert_eq!(hash1, hash2, "hash must be deterministic");
    }

    #[test]
    fn test_hash_text_different_text() {
        let provider = SimEmbeddingProvider::new(SEED);

        let hash1 = provider.hash_text("test1");
        let hash2 = provider.hash_text("test2");

        assert_ne!(hash1, hash2, "different text must produce different hashes");
    }

    #[test]
    fn test_hash_text_different_seed() {
        let seeded_42 = SimEmbeddingProvider::new(42);
        let seeded_99 = SimEmbeddingProvider::new(99);

        let hash1 = seeded_42.hash_text("test");
        let hash2 = seeded_99.hash_text("test");

        assert_ne!(hash1, hash2, "different seed must produce different hashes");
    }
}