umi_memory/embedding/
mod.rs

1//! Embedding Provider Trait - Unified Interface for Text Embeddings
2//!
3//! TigerStyle: Simulation-first embedding generation.
4//!
5//! See ADR-019 for design rationale.
6//!
7//! # Architecture
8//!
9//! ```text
10//! EmbeddingProvider (trait)
11//! ├── SimEmbeddingProvider    (always available, deterministic)
12//! └── OpenAIEmbeddingProvider (feature: embedding-openai)
13//! ```
14//!
15//! # Usage
16//!
17//! ```rust
18//! use umi_memory::embedding::{EmbeddingProvider, SimEmbeddingProvider};
19//!
20//! #[tokio::main]
21//! async fn main() {
22//!     // Simulation (always available, deterministic)
23//!     let provider = SimEmbeddingProvider::with_seed(42);
24//!
25//!     let embedding = provider.embed("Alice works at Acme").await.unwrap();
26//!     println!("Generated {} dimensional embedding", embedding.len());
27//! }
28//! ```
29
30mod sim;
31
32#[cfg(feature = "embedding-openai")]
33mod openai;
34
35pub use sim::SimEmbeddingProvider;
36
37#[cfg(feature = "embedding-openai")]
38pub use openai::OpenAIEmbeddingProvider;
39
40use async_trait::async_trait;
41
42// =============================================================================
43// Error Types
44// =============================================================================
45
46/// Unified error type for all embedding providers.
47///
48/// TigerStyle: Explicit variants for all failure modes.
49#[derive(Debug, Clone, thiserror::Error)]
50pub enum EmbeddingError {
51    /// Request timed out
52    #[error("Request timed out")]
53    Timeout,
54
55    /// Rate limit exceeded
56    #[error("Rate limit exceeded, retry after {retry_after_secs:?}s")]
57    RateLimit {
58        /// Seconds until rate limit resets (if known)
59        retry_after_secs: Option<u64>,
60    },
61
62    /// Input text too long
63    #[error("Context length exceeded: {tokens} tokens")]
64    ContextOverflow {
65        /// Number of tokens that exceeded the limit
66        tokens: usize,
67    },
68
69    /// Invalid response from provider
70    #[error("Invalid response: {message}")]
71    InvalidResponse {
72        /// Description of what was invalid
73        message: String,
74    },
75
76    /// Service unavailable
77    #[error("Service unavailable: {message}")]
78    ServiceUnavailable {
79        /// Reason for unavailability
80        message: String,
81    },
82
83    /// Authentication failed
84    #[error("Authentication failed")]
85    AuthenticationFailed,
86
87    /// JSON serialization/deserialization error
88    #[error("JSON error: {message}")]
89    JsonError {
90        /// Description of the JSON error
91        message: String,
92    },
93
94    /// Network error
95    #[error("Network error: {message}")]
96    NetworkError {
97        /// Description of the network error
98        message: String,
99    },
100
101    /// Invalid request parameters
102    #[error("Invalid request: {message}")]
103    InvalidRequest {
104        /// Description of what was invalid
105        message: String,
106    },
107
108    /// Empty input provided
109    #[error("Empty input provided")]
110    EmptyInput,
111
112    /// Dimension mismatch in returned embedding
113    #[error("Dimension mismatch: expected {expected}, got {actual}")]
114    DimensionMismatch {
115        /// Expected dimensions
116        expected: usize,
117        /// Actual dimensions received
118        actual: usize,
119    },
120}
121
122impl EmbeddingError {
123    /// Create a timeout error.
124    #[must_use]
125    pub fn timeout() -> Self {
126        Self::Timeout
127    }
128
129    /// Create a rate limit error.
130    #[must_use]
131    pub fn rate_limit(retry_after_secs: Option<u64>) -> Self {
132        Self::RateLimit { retry_after_secs }
133    }
134
135    /// Create a context overflow error.
136    #[must_use]
137    pub fn context_overflow(tokens: usize) -> Self {
138        Self::ContextOverflow { tokens }
139    }
140
141    /// Create an invalid response error.
142    #[must_use]
143    pub fn invalid_response(message: impl Into<String>) -> Self {
144        Self::InvalidResponse {
145            message: message.into(),
146        }
147    }
148
149    /// Create a service unavailable error.
150    #[must_use]
151    pub fn service_unavailable(message: impl Into<String>) -> Self {
152        Self::ServiceUnavailable {
153            message: message.into(),
154        }
155    }
156
157    /// Create a JSON error.
158    #[must_use]
159    pub fn json_error(message: impl Into<String>) -> Self {
160        Self::JsonError {
161            message: message.into(),
162        }
163    }
164
165    /// Create a network error.
166    #[must_use]
167    pub fn network_error(message: impl Into<String>) -> Self {
168        Self::NetworkError {
169            message: message.into(),
170        }
171    }
172
173    /// Create an invalid request error.
174    #[must_use]
175    pub fn invalid_request(message: impl Into<String>) -> Self {
176        Self::InvalidRequest {
177            message: message.into(),
178        }
179    }
180
181    /// Create a dimension mismatch error.
182    #[must_use]
183    pub fn dimension_mismatch(expected: usize, actual: usize) -> Self {
184        Self::DimensionMismatch { expected, actual }
185    }
186
187    /// Check if this error is retryable.
188    #[must_use]
189    pub fn is_retryable(&self) -> bool {
190        matches!(
191            self,
192            Self::Timeout | Self::RateLimit { .. } | Self::ServiceUnavailable { .. }
193        )
194    }
195}
196
197// =============================================================================
198// Provider Trait
199// =============================================================================
200
201/// Trait for embedding providers.
202///
203/// TigerStyle: Unified interface for simulation and production.
204///
205/// All providers implement this trait, allowing higher-level components
206/// to work with any provider without knowing the concrete type.
207///
208/// # Example
209///
210/// ```rust
211/// use umi_memory::embedding::{EmbeddingProvider, SimEmbeddingProvider};
212///
213/// async fn generate_embedding<P: EmbeddingProvider>(provider: &P, text: &str) -> Vec<f32> {
214///     provider.embed(text).await.unwrap()
215/// }
216/// ```
217#[async_trait]
218pub trait EmbeddingProvider: Send + Sync {
219    /// Generate embedding for a single text.
220    ///
221    /// # Arguments
222    /// * `text` - The text to embed
223    ///
224    /// # Returns
225    /// Vector of floats representing the embedding (normalized to unit vector)
226    ///
227    /// # Errors
228    /// Returns `EmbeddingError` on failure (rate limit, network error, etc.)
229    async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError>;
230
231    /// Generate embeddings for multiple texts (batch).
232    ///
233    /// This is more efficient than calling `embed` multiple times as it can
234    /// leverage API batching capabilities.
235    ///
236    /// # Arguments
237    /// * `texts` - Slice of texts to embed
238    ///
239    /// # Returns
240    /// Vector of embeddings, one per input text
241    ///
242    /// # Errors
243    /// Returns `EmbeddingError` on failure
244    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError>;
245
246    /// Get the embedding dimensions (e.g., 1536 for text-embedding-3-small).
247    fn dimensions(&self) -> usize;
248
249    /// Provider name for logging and debugging.
250    fn name(&self) -> &'static str;
251
252    /// Check if this is a simulation provider.
253    ///
254    /// Returns `true` for `SimEmbeddingProvider`, `false` for real providers.
255    fn is_simulation(&self) -> bool;
256}
257
258// =============================================================================
259// Helper Functions
260// =============================================================================
261
262/// Validate that an embedding has the expected dimensions.
263///
264/// # Arguments
265/// * `embedding` - The embedding to validate
266/// * `expected` - Expected number of dimensions
267///
268/// # Errors
269/// Returns `EmbeddingError::DimensionMismatch` if dimensions don't match
270pub fn validate_dimensions(embedding: &[f32], expected: usize) -> Result<(), EmbeddingError> {
271    if embedding.len() != expected {
272        return Err(EmbeddingError::dimension_mismatch(
273            expected,
274            embedding.len(),
275        ));
276    }
277    Ok(())
278}
279
/// Normalize a vector to unit length (L2 norm = 1).
///
/// Unit-length vectors let cosine similarity be computed as a plain dot
/// product.
///
/// The norm is accumulated in `f64`. The square of any finite `f32` is
/// representable in `f64`, so very large components no longer overflow the
/// sum to infinity (which used to yield an all-zero result) and very small
/// non-zero components no longer underflow it to zero (which used to trigger
/// the zero-vector panic).
///
/// # Arguments
/// * `vec` - The vector to normalize
///
/// # Returns
/// A new vector with each component divided by the L2 norm of `vec`.
///
/// # Panics
/// Panics if the computed norm is not strictly positive: the input is empty,
/// all zeros, or contains NaN (cannot normalize).
#[must_use]
// The f64 -> f32 narrowing below is intentional: the public API is f32.
#[allow(clippy::cast_possible_truncation)]
pub fn normalize_vector(vec: &[f32]) -> Vec<f32> {
    let norm = vec
        .iter()
        .map(|&x| {
            let wide = f64::from(x);
            wide * wide
        })
        .sum::<f64>()
        .sqrt();

    // NaN compares false here as well, so NaN input is rejected too.
    assert!(norm > 0.0, "Cannot normalize zero vector");

    // Divide in f64 as well so the quotient is rounded to f32 only once.
    vec.iter().map(|&x| (f64::from(x) / norm) as f32).collect()
}
297
/// Report whether a vector has (approximately) unit length.
///
/// # Arguments
/// * `vec` - The vector to check
/// * `tolerance` - Maximum accepted deviation of the L2 norm from 1.0; the
///   comparison is strict, so a deviation equal to `tolerance` fails
///
/// # Returns
/// `true` when `|norm - 1.0| < tolerance`, `false` otherwise.
#[must_use]
pub fn is_normalized(vec: &[f32], tolerance: f32) -> bool {
    let squared_sum = vec
        .iter()
        .map(|component| component * component)
        .sum::<f32>();
    let deviation = (squared_sum.sqrt() - 1.0).abs();
    deviation < tolerance
}
308
309// =============================================================================
310// Tests
311// =============================================================================
312
#[cfg(test)]
mod tests {
    use super::*;
    use crate::constants::EMBEDDING_DIMENSIONS_COUNT;

    /// Tolerance shared by the floating-point comparisons below.
    const TOLERANCE: f32 = 0.001;

    /// Each constructor must produce the matching variant with its payload.
    #[test]
    fn test_embedding_error_constructors() {
        assert!(matches!(
            EmbeddingError::timeout(),
            EmbeddingError::Timeout
        ));

        assert!(matches!(
            EmbeddingError::rate_limit(Some(60)),
            EmbeddingError::RateLimit {
                retry_after_secs: Some(60)
            }
        ));

        assert!(matches!(
            EmbeddingError::context_overflow(10000),
            EmbeddingError::ContextOverflow { tokens: 10000 }
        ));

        assert!(matches!(
            EmbeddingError::invalid_response("bad format"),
            EmbeddingError::InvalidResponse { .. }
        ));

        assert!(matches!(
            EmbeddingError::dimension_mismatch(1536, 768),
            EmbeddingError::DimensionMismatch {
                expected: 1536,
                actual: 768
            }
        ));
    }

    /// Transient failures are retryable; permanent ones are not.
    #[test]
    fn test_embedding_error_is_retryable() {
        let retryable = [
            EmbeddingError::timeout(),
            EmbeddingError::rate_limit(Some(60)),
            EmbeddingError::service_unavailable("down"),
        ];
        for err in &retryable {
            assert!(err.is_retryable());
        }

        let permanent = [
            EmbeddingError::AuthenticationFailed,
            EmbeddingError::EmptyInput,
            EmbeddingError::json_error("parse failed"),
        ];
        for err in &permanent {
            assert!(!err.is_retryable());
        }
    }

    /// Matching length passes, any other length is rejected.
    #[test]
    fn test_validate_dimensions() {
        let right_size = vec![0.1; EMBEDDING_DIMENSIONS_COUNT];
        let wrong_size = vec![0.1; 768];

        assert!(validate_dimensions(&right_size, EMBEDDING_DIMENSIONS_COUNT).is_ok());
        assert!(validate_dimensions(&wrong_size, EMBEDDING_DIMENSIONS_COUNT).is_err());
    }

    /// A 3-4-5 triangle normalizes to [0.6, 0.8] with unit length.
    #[test]
    fn test_normalize_vector() {
        let normalized = normalize_vector(&[3.0, 4.0]);

        assert!((normalized[0] - 0.6).abs() < TOLERANCE);
        assert!((normalized[1] - 0.8).abs() < TOLERANCE);
        assert!(is_normalized(&normalized, TOLERANCE));
    }

    /// Unit vectors are accepted, longer vectors are rejected.
    #[test]
    fn test_is_normalized() {
        assert!(is_normalized(&[1.0, 0.0, 0.0], TOLERANCE));
        assert!(!is_normalized(&[2.0, 0.0, 0.0], TOLERANCE));
        // 3-4-5 triangle scaled to unit length.
        assert!(is_normalized(&[0.6, 0.8], TOLERANCE));
    }

    /// Normalizing the zero vector must panic with the documented message.
    #[test]
    #[should_panic(expected = "Cannot normalize zero vector")]
    fn test_normalize_zero_vector() {
        let _ = normalize_vector(&[0.0, 0.0, 0.0]);
    }
}
401}