umi_memory/embedding/
sim.rs

1//! Simulated Embedding Provider for Deterministic Testing
2//!
3//! TigerStyle: Deterministic, reproducible embeddings for DST.
4//!
5//! # Overview
6//!
7//! `SimEmbeddingProvider` generates embeddings deterministically:
8//! - Same text + same seed = same embedding (always)
9//! - No external API calls
10//! - Perfect for testing and reproducibility
11//!
12//! # Algorithm
13//!
14//! 1. Hash text + seed to get base seed
15//! 2. Use `DeterministicRng` to generate random floats in [-1, 1]
16//! 3. Normalize to unit vector (L2 norm = 1)
17//! 4. Return consistent 1536-dimensional embedding
18//!
19//! # Example
20//!
21//! ```rust
22//! use umi_memory::embedding::{EmbeddingProvider, SimEmbeddingProvider};
23//!
24//! #[tokio::main]
25//! async fn main() {
26//!     let provider = SimEmbeddingProvider::with_seed(42);
27//!
28//!     let emb1 = provider.embed("Alice works at Acme").await.unwrap();
29//!     let emb2 = provider.embed("Alice works at Acme").await.unwrap();
30//!
31//!     // Same text = same embedding
32//!     assert_eq!(emb1, emb2);
33//! }
34//! ```
35
36use std::collections::hash_map::DefaultHasher;
37use std::hash::{Hash, Hasher};
38
39use std::sync::Arc;
40
41use async_trait::async_trait;
42
43use super::{EmbeddingError, EmbeddingProvider};
44use crate::constants::EMBEDDING_DIMENSIONS_COUNT;
45use crate::dst::{DeterministicRng, FaultInjector};
46
47// =============================================================================
48// SimEmbeddingProvider
49// =============================================================================
50
51/// In-memory embedding provider for deterministic simulation testing.
52///
53/// Features:
54/// - Deterministic: same text + same seed = same embedding
55/// - No external dependencies
56/// - Fast (no network calls)
57/// - Normalized embeddings (unit vectors)
58/// - Fault injection support for DST
59#[derive(Clone)]
60pub struct SimEmbeddingProvider {
61    /// Base seed for RNG
62    seed: u64,
63    /// Embedding dimensions
64    dimensions: usize,
65    /// Fault injector (optional for DST)
66    fault_injector: Option<Arc<FaultInjector>>,
67}
68
69impl SimEmbeddingProvider {
70    /// Create a new simulated embedding provider with the given seed.
71    ///
72    /// # Arguments
73    /// * `seed` - Base seed for deterministic generation
74    ///
75    /// # Example
76    ///
77    /// ```rust
78    /// use umi_memory::embedding::SimEmbeddingProvider;
79    ///
80    /// let provider = SimEmbeddingProvider::new(42);
81    /// ```
82    #[must_use]
83    pub fn new(seed: u64) -> Self {
84        Self {
85            seed,
86            dimensions: EMBEDDING_DIMENSIONS_COUNT,
87            fault_injector: None,
88        }
89    }
90
91    /// Create with explicit seed (alias for `new`).
92    #[must_use]
93    pub fn with_seed(seed: u64) -> Self {
94        Self::new(seed)
95    }
96
97    /// Create with fault injection enabled.
98    #[must_use]
99    pub fn with_faults(seed: u64, fault_injector: Arc<FaultInjector>) -> Self {
100        Self {
101            seed,
102            dimensions: EMBEDDING_DIMENSIONS_COUNT,
103            fault_injector: Some(fault_injector),
104        }
105    }
106
107    /// Check if a fault should be injected.
108    fn should_inject_fault(&self, operation: &str) -> bool {
109        if let Some(ref injector) = self.fault_injector {
110            injector.should_inject(operation).is_some()
111        } else {
112            false
113        }
114    }
115
116    /// Hash text to get a deterministic seed.
117    ///
118    /// Combines the base seed with text hash for consistent results.
119    fn hash_text(&self, text: &str) -> u64 {
120        let mut hasher = DefaultHasher::new();
121        self.seed.hash(&mut hasher);
122        text.hash(&mut hasher);
123        hasher.finish()
124    }
125
126    /// Generate a deterministic embedding for text.
127    ///
128    /// Algorithm:
129    /// 1. Hash text + seed
130    /// 2. Generate N random floats in [-1, 1]
131    /// 3. Normalize to unit vector
132    fn generate_embedding(&self, text: &str) -> Vec<f32> {
133        // Hash text to get deterministic seed
134        let text_seed = self.hash_text(text);
135        let mut rng = DeterministicRng::new(text_seed);
136
137        // Generate random values in [-1, 1]
138        let mut embedding: Vec<f32> = (0..self.dimensions)
139            .map(|_| {
140                let val = rng.next_float();
141                (val * 2.0 - 1.0) as f32 // Map [0, 1] to [-1, 1]
142            })
143            .collect();
144
145        // Normalize to unit vector
146        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
147        if norm > 0.0 {
148            for val in &mut embedding {
149                *val /= norm;
150            }
151        }
152
153        // Postcondition: embedding is normalized
154        debug_assert!(
155            {
156                let check_norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
157                (check_norm - 1.0).abs() < 0.001
158            },
159            "embedding must be normalized to unit vector"
160        );
161        debug_assert_eq!(
162            embedding.len(),
163            self.dimensions,
164            "embedding must have correct dimensions"
165        );
166
167        embedding
168    }
169}
170
171#[async_trait]
172impl EmbeddingProvider for SimEmbeddingProvider {
173    async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
174        // Precondition: text must not be empty
175        if text.is_empty() {
176            return Err(EmbeddingError::EmptyInput);
177        }
178
179        // Fault injection
180        if self.should_inject_fault("embedding_timeout") {
181            return Err(EmbeddingError::Timeout);
182        }
183        if self.should_inject_fault("embedding_rate_limit") {
184            return Err(EmbeddingError::rate_limit(Some(60)));
185        }
186        if self.should_inject_fault("embedding_service_unavailable") {
187            return Err(EmbeddingError::service_unavailable("Simulated failure"));
188        }
189
190        Ok(self.generate_embedding(text))
191    }
192
193    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
194        // Precondition: batch must not be empty
195        if texts.is_empty() {
196            return Err(EmbeddingError::invalid_request("batch cannot be empty"));
197        }
198
199        // Fault injection (same as embed)
200        if self.should_inject_fault("embedding_timeout") {
201            return Err(EmbeddingError::Timeout);
202        }
203        if self.should_inject_fault("embedding_rate_limit") {
204            return Err(EmbeddingError::rate_limit(Some(60)));
205        }
206        if self.should_inject_fault("embedding_service_unavailable") {
207            return Err(EmbeddingError::service_unavailable("Simulated failure"));
208        }
209
210        // Generate embedding for each text
211        let mut embeddings = Vec::with_capacity(texts.len());
212        for text in texts {
213            if text.is_empty() {
214                return Err(EmbeddingError::EmptyInput);
215            }
216            embeddings.push(self.generate_embedding(text));
217        }
218
219        // Postcondition: same number of embeddings as inputs
220        debug_assert_eq!(
221            embeddings.len(),
222            texts.len(),
223            "must return one embedding per input"
224        );
225
226        Ok(embeddings)
227    }
228
229    fn dimensions(&self) -> usize {
230        self.dimensions
231    }
232
233    fn name(&self) -> &'static str {
234        "sim-embedding"
235    }
236
237    fn is_simulation(&self) -> bool {
238        true
239    }
240}
241
242// =============================================================================
243// Tests
244// =============================================================================
245
#[cfg(test)]
mod tests {
    use super::*;

    /// Sentence used by most single-text tests.
    const ALICE: &str = "Alice works at Acme";
    /// A second sentence that differs from [`ALICE`].
    const BOB: &str = "Bob works at TechCo";

    /// Euclidean (L2) length of an embedding.
    fn l2_norm(vector: &[f32]) -> f32 {
        vector.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    /// Embed `text` with a freshly constructed provider seeded with `seed`.
    async fn embed_with_seed(seed: u64, text: &str) -> Vec<f32> {
        SimEmbeddingProvider::new(seed).embed(text).await.unwrap()
    }

    // =========================================================================
    // Basic Behaviour Tests
    // =========================================================================

    #[tokio::test]
    async fn test_sim_embedding_basic() {
        let sim = SimEmbeddingProvider::new(42);
        let vector = sim.embed(ALICE).await.unwrap();

        assert_eq!(vector.len(), EMBEDDING_DIMENSIONS_COUNT);
    }

    #[tokio::test]
    async fn test_sim_embedding_deterministic() {
        let sim = SimEmbeddingProvider::new(42);

        let first = sim.embed(ALICE).await.unwrap();
        let second = sim.embed(ALICE).await.unwrap();

        // Repeating the call on one provider must give an identical vector
        assert_eq!(first, second);
    }

    #[tokio::test]
    async fn test_sim_embedding_different_text() {
        let sim = SimEmbeddingProvider::new(42);

        let alice = sim.embed(ALICE).await.unwrap();
        let bob = sim.embed(BOB).await.unwrap();

        // Distinct inputs must not collapse onto the same vector
        assert_ne!(alice, bob);
    }

    #[tokio::test]
    async fn test_sim_embedding_different_seed() {
        let seeded_42 = SimEmbeddingProvider::new(42);
        let seeded_99 = SimEmbeddingProvider::new(99);

        let from_42 = seeded_42.embed(ALICE).await.unwrap();
        let from_99 = seeded_99.embed(ALICE).await.unwrap();

        // The seed is part of the hash, so the vectors must differ
        assert_ne!(from_42, from_99);
    }

    #[tokio::test]
    async fn test_sim_embedding_normalized() {
        let sim = SimEmbeddingProvider::new(42);
        let vector = sim.embed(ALICE).await.unwrap();

        // L2 norm must be approximately 1.0
        let norm = l2_norm(&vector);
        assert!((norm - 1.0).abs() < 0.001, "embedding must be normalized");
    }

    #[tokio::test]
    async fn test_sim_embedding_empty_text() {
        let sim = SimEmbeddingProvider::new(42);
        let outcome = sim.embed("").await;

        assert!(matches!(outcome, Err(EmbeddingError::EmptyInput)));
    }

    #[tokio::test]
    async fn test_sim_embedding_batch() {
        let sim = SimEmbeddingProvider::new(42);
        let inputs = [ALICE, BOB];

        let batch = sim.embed_batch(&inputs).await.unwrap();

        assert_eq!(batch.len(), 2);
        assert_eq!(batch[0].len(), EMBEDDING_DIMENSIONS_COUNT);
        assert_eq!(batch[1].len(), EMBEDDING_DIMENSIONS_COUNT);

        // Each batch entry must equal the corresponding single-text result
        let alone_first = sim.embed(inputs[0]).await.unwrap();
        let alone_second = sim.embed(inputs[1]).await.unwrap();

        assert_eq!(batch[0], alone_first);
        assert_eq!(batch[1], alone_second);
    }

    #[tokio::test]
    async fn test_sim_embedding_batch_empty() {
        let sim = SimEmbeddingProvider::new(42);
        let inputs: [&str; 0] = [];

        let outcome = sim.embed_batch(&inputs).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_sim_embedding_batch_with_empty_text() {
        let sim = SimEmbeddingProvider::new(42);
        let inputs = ["Alice", ""];

        let outcome = sim.embed_batch(&inputs).await;
        assert!(matches!(outcome, Err(EmbeddingError::EmptyInput)));
    }

    #[tokio::test]
    async fn test_sim_embedding_provider_traits() {
        let sim = SimEmbeddingProvider::new(42);

        assert_eq!(sim.dimensions(), EMBEDDING_DIMENSIONS_COUNT);
        assert_eq!(sim.name(), "sim-embedding");
        assert!(sim.is_simulation());
    }

    // =========================================================================
    // Determinism Property Tests
    // =========================================================================

    #[tokio::test]
    async fn test_determinism_same_seed_same_results() {
        // Two independent providers built from the same seed
        let first_run = embed_with_seed(42, "test text").await;
        let second_run = embed_with_seed(42, "test text").await;

        assert_eq!(first_run, second_run, "same seed must produce same results");
    }

    #[tokio::test]
    async fn test_determinism_different_seed_different_results() {
        let from_42 = embed_with_seed(42, "test text").await;
        let from_43 = embed_with_seed(43, "test text").await;

        assert_ne!(
            from_42, from_43,
            "different seeds must produce different results"
        );
    }

    #[tokio::test]
    async fn test_batch_determinism() {
        let sim = SimEmbeddingProvider::new(42);
        let inputs = ["text1", "text2", "text3"];

        let first_pass = sim.embed_batch(&inputs).await.unwrap();
        let second_pass = sim.embed_batch(&inputs).await.unwrap();

        assert_eq!(first_pass, second_pass, "batch must be deterministic");
    }

    // =========================================================================
    // Normalization Property Tests
    // =========================================================================

    #[tokio::test]
    async fn test_all_embeddings_normalized() {
        let sim = SimEmbeddingProvider::new(42);
        // Inputs of increasing length
        let samples = [
            "short",
            "longer text here",
            "even longer text with more words to test different lengths",
        ];

        for sample in samples {
            let vector = sim.embed(sample).await.unwrap();
            let norm = l2_norm(&vector);
            assert!(
                (norm - 1.0).abs() < 0.001,
                "embedding for '{}' must be normalized, got norm {}",
                sample,
                norm
            );
        }
    }

    // =========================================================================
    // Hash Function Tests
    // =========================================================================

    #[test]
    fn test_hash_text_deterministic() {
        let sim = SimEmbeddingProvider::new(42);

        let first = sim.hash_text("test");
        let second = sim.hash_text("test");

        assert_eq!(first, second, "hash must be deterministic");
    }

    #[test]
    fn test_hash_text_different_text() {
        let sim = SimEmbeddingProvider::new(42);

        let left = sim.hash_text("test1");
        let right = sim.hash_text("test2");

        assert_ne!(left, right, "different text must produce different hashes");
    }

    #[test]
    fn test_hash_text_different_seed() {
        let left = SimEmbeddingProvider::new(42).hash_text("test");
        let right = SimEmbeddingProvider::new(99).hash_text("test");

        assert_ne!(left, right, "different seed must produce different hashes");
    }
}