umi_memory/embedding/mod.rs
1//! Embedding Provider Trait - Unified Interface for Text Embeddings
2//!
3//! TigerStyle: Simulation-first embedding generation.
4//!
5//! See ADR-019 for design rationale.
6//!
7//! # Architecture
8//!
9//! ```text
10//! EmbeddingProvider (trait)
11//! ├── SimEmbeddingProvider (always available, deterministic)
12//! └── OpenAIEmbeddingProvider (feature: embedding-openai)
13//! ```
14//!
15//! # Usage
16//!
17//! ```rust
18//! use umi_memory::embedding::{EmbeddingProvider, SimEmbeddingProvider};
19//!
20//! #[tokio::main]
21//! async fn main() {
22//! // Simulation (always available, deterministic)
23//! let provider = SimEmbeddingProvider::with_seed(42);
24//!
25//! let embedding = provider.embed("Alice works at Acme").await.unwrap();
26//! println!("Generated {} dimensional embedding", embedding.len());
27//! }
28//! ```
29
30mod sim;
31
32#[cfg(feature = "embedding-openai")]
33mod openai;
34
35pub use sim::SimEmbeddingProvider;
36
37#[cfg(feature = "embedding-openai")]
38pub use openai::OpenAIEmbeddingProvider;
39
40use async_trait::async_trait;
41
42// =============================================================================
43// Error Types
44// =============================================================================
45
46/// Unified error type for all embedding providers.
47///
48/// TigerStyle: Explicit variants for all failure modes.
49#[derive(Debug, Clone, thiserror::Error)]
50pub enum EmbeddingError {
51 /// Request timed out
52 #[error("Request timed out")]
53 Timeout,
54
55 /// Rate limit exceeded
56 #[error("Rate limit exceeded, retry after {retry_after_secs:?}s")]
57 RateLimit {
58 /// Seconds until rate limit resets (if known)
59 retry_after_secs: Option<u64>,
60 },
61
62 /// Input text too long
63 #[error("Context length exceeded: {tokens} tokens")]
64 ContextOverflow {
65 /// Number of tokens that exceeded the limit
66 tokens: usize,
67 },
68
69 /// Invalid response from provider
70 #[error("Invalid response: {message}")]
71 InvalidResponse {
72 /// Description of what was invalid
73 message: String,
74 },
75
76 /// Service unavailable
77 #[error("Service unavailable: {message}")]
78 ServiceUnavailable {
79 /// Reason for unavailability
80 message: String,
81 },
82
83 /// Authentication failed
84 #[error("Authentication failed")]
85 AuthenticationFailed,
86
87 /// JSON serialization/deserialization error
88 #[error("JSON error: {message}")]
89 JsonError {
90 /// Description of the JSON error
91 message: String,
92 },
93
94 /// Network error
95 #[error("Network error: {message}")]
96 NetworkError {
97 /// Description of the network error
98 message: String,
99 },
100
101 /// Invalid request parameters
102 #[error("Invalid request: {message}")]
103 InvalidRequest {
104 /// Description of what was invalid
105 message: String,
106 },
107
108 /// Empty input provided
109 #[error("Empty input provided")]
110 EmptyInput,
111
112 /// Dimension mismatch in returned embedding
113 #[error("Dimension mismatch: expected {expected}, got {actual}")]
114 DimensionMismatch {
115 /// Expected dimensions
116 expected: usize,
117 /// Actual dimensions received
118 actual: usize,
119 },
120}
121
122impl EmbeddingError {
123 /// Create a timeout error.
124 #[must_use]
125 pub fn timeout() -> Self {
126 Self::Timeout
127 }
128
129 /// Create a rate limit error.
130 #[must_use]
131 pub fn rate_limit(retry_after_secs: Option<u64>) -> Self {
132 Self::RateLimit { retry_after_secs }
133 }
134
135 /// Create a context overflow error.
136 #[must_use]
137 pub fn context_overflow(tokens: usize) -> Self {
138 Self::ContextOverflow { tokens }
139 }
140
141 /// Create an invalid response error.
142 #[must_use]
143 pub fn invalid_response(message: impl Into<String>) -> Self {
144 Self::InvalidResponse {
145 message: message.into(),
146 }
147 }
148
149 /// Create a service unavailable error.
150 #[must_use]
151 pub fn service_unavailable(message: impl Into<String>) -> Self {
152 Self::ServiceUnavailable {
153 message: message.into(),
154 }
155 }
156
157 /// Create a JSON error.
158 #[must_use]
159 pub fn json_error(message: impl Into<String>) -> Self {
160 Self::JsonError {
161 message: message.into(),
162 }
163 }
164
165 /// Create a network error.
166 #[must_use]
167 pub fn network_error(message: impl Into<String>) -> Self {
168 Self::NetworkError {
169 message: message.into(),
170 }
171 }
172
173 /// Create an invalid request error.
174 #[must_use]
175 pub fn invalid_request(message: impl Into<String>) -> Self {
176 Self::InvalidRequest {
177 message: message.into(),
178 }
179 }
180
181 /// Create a dimension mismatch error.
182 #[must_use]
183 pub fn dimension_mismatch(expected: usize, actual: usize) -> Self {
184 Self::DimensionMismatch { expected, actual }
185 }
186
187 /// Check if this error is retryable.
188 #[must_use]
189 pub fn is_retryable(&self) -> bool {
190 matches!(
191 self,
192 Self::Timeout | Self::RateLimit { .. } | Self::ServiceUnavailable { .. }
193 )
194 }
195}
196
197// =============================================================================
198// Provider Trait
199// =============================================================================
200
201/// Trait for embedding providers.
202///
203/// TigerStyle: Unified interface for simulation and production.
204///
205/// All providers implement this trait, allowing higher-level components
206/// to work with any provider without knowing the concrete type.
207///
208/// # Example
209///
210/// ```rust
211/// use umi_memory::embedding::{EmbeddingProvider, SimEmbeddingProvider};
212///
213/// async fn generate_embedding<P: EmbeddingProvider>(provider: &P, text: &str) -> Vec<f32> {
214/// provider.embed(text).await.unwrap()
215/// }
216/// ```
217#[async_trait]
218pub trait EmbeddingProvider: Send + Sync {
219 /// Generate embedding for a single text.
220 ///
221 /// # Arguments
222 /// * `text` - The text to embed
223 ///
224 /// # Returns
225 /// Vector of floats representing the embedding (normalized to unit vector)
226 ///
227 /// # Errors
228 /// Returns `EmbeddingError` on failure (rate limit, network error, etc.)
229 async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError>;
230
231 /// Generate embeddings for multiple texts (batch).
232 ///
233 /// This is more efficient than calling `embed` multiple times as it can
234 /// leverage API batching capabilities.
235 ///
236 /// # Arguments
237 /// * `texts` - Slice of texts to embed
238 ///
239 /// # Returns
240 /// Vector of embeddings, one per input text
241 ///
242 /// # Errors
243 /// Returns `EmbeddingError` on failure
244 async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError>;
245
246 /// Get the embedding dimensions (e.g., 1536 for text-embedding-3-small).
247 fn dimensions(&self) -> usize;
248
249 /// Provider name for logging and debugging.
250 fn name(&self) -> &'static str;
251
252 /// Check if this is a simulation provider.
253 ///
254 /// Returns `true` for `SimEmbeddingProvider`, `false` for real providers.
255 fn is_simulation(&self) -> bool;
256}
257
258// =============================================================================
259// Helper Functions
260// =============================================================================
261
262/// Validate that an embedding has the expected dimensions.
263///
264/// # Arguments
265/// * `embedding` - The embedding to validate
266/// * `expected` - Expected number of dimensions
267///
268/// # Errors
269/// Returns `EmbeddingError::DimensionMismatch` if dimensions don't match
270pub fn validate_dimensions(embedding: &[f32], expected: usize) -> Result<(), EmbeddingError> {
271 if embedding.len() != expected {
272 return Err(EmbeddingError::dimension_mismatch(
273 expected,
274 embedding.len(),
275 ));
276 }
277 Ok(())
278}
279
/// Normalize a vector to unit length (L2 norm = 1).
///
/// Unit-length vectors let cosine similarity be computed efficiently.
///
/// The sum of squares is first accumulated in `f32`. When that sum is an
/// ordinary (normal, finite) number the result is computed entirely in `f32`.
/// When it underflows to zero/subnormal or overflows to infinity, the norm is
/// recomputed in `f64` so that vectors with very small or very large
/// components are still scaled correctly instead of being rejected as "zero"
/// or collapsed to all zeros.
///
/// Non-finite components are not treated specially: an infinite component
/// yields an infinite norm, so finite entries become `0.0` and infinite ones
/// become NaN.
///
/// # Arguments
/// * `vec` - The vector to normalize
///
/// # Panics
/// Panics with "Cannot normalize zero vector" if the L2 norm is not strictly
/// positive: an all-zero vector, an empty slice, or a vector containing NaN.
#[must_use]
// Each quotient has magnitude of at most about 1, so narrowing f64 -> f32 cannot overflow.
#[allow(clippy::cast_possible_truncation)]
pub fn normalize_vector(vec: &[f32]) -> Vec<f32> {
    let sum_sq: f32 = vec.iter().map(|x| x * x).sum();

    // Common case: the f32 accumulation neither underflowed nor overflowed.
    if sum_sq.is_normal() {
        let norm = sum_sq.sqrt();
        return vec.iter().map(|x| x / norm).collect();
    }

    // Degenerate f32 sum (0, subnormal, infinite or NaN): redo the norm in f64,
    // whose range comfortably covers the square of any finite f32.
    let norm = vec
        .iter()
        .map(|&x| f64::from(x) * f64::from(x))
        .sum::<f64>()
        .sqrt();

    assert!(norm > 0.0, "Cannot normalize zero vector");

    vec.iter().map(|&x| (f64::from(x) / norm) as f32).collect()
}
297
/// Report whether a vector has approximately unit length (L2 norm ≈ 1.0).
///
/// # Arguments
/// * `vec` - The vector to check
/// * `tolerance` - Largest accepted distance between the norm and 1.0. There
///   is no built-in default, so callers always pass it explicitly. The
///   comparison is strict: a deviation equal to `tolerance` is rejected.
///
/// # Returns
/// `true` when the absolute difference between the L2 norm and 1.0 is below
/// `tolerance`. An empty slice has norm 0 and is therefore only accepted when
/// `tolerance` exceeds 1.0.
#[must_use]
pub fn is_normalized(vec: &[f32], tolerance: f32) -> bool {
    let sum_of_squares = vec
        .iter()
        .fold(0.0_f32, |acc, &component| acc + component * component);
    let deviation = (sum_of_squares.sqrt() - 1.0).abs();

    deviation < tolerance
}
308
309// =============================================================================
310// Tests
311// =============================================================================
312
#[cfg(test)]
mod tests {
    use super::*;
    use crate::constants::EMBEDDING_DIMENSIONS_COUNT;

    /// Each convenience constructor must produce its matching variant with
    /// the payload it was given.
    #[test]
    fn test_embedding_error_constructors() {
        let err = EmbeddingError::timeout();
        assert!(matches!(err, EmbeddingError::Timeout));

        let err = EmbeddingError::rate_limit(Some(60));
        assert!(matches!(
            err,
            EmbeddingError::RateLimit {
                retry_after_secs: Some(60)
            }
        ));

        let err = EmbeddingError::context_overflow(10000);
        assert!(matches!(
            err,
            EmbeddingError::ContextOverflow { tokens: 10000 }
        ));

        let err = EmbeddingError::invalid_response("bad format");
        assert!(matches!(
            err,
            EmbeddingError::InvalidResponse { message } if message == "bad format"
        ));

        let err = EmbeddingError::service_unavailable("down");
        assert!(matches!(
            err,
            EmbeddingError::ServiceUnavailable { message } if message == "down"
        ));

        let err = EmbeddingError::json_error("parse failed");
        assert!(matches!(
            err,
            EmbeddingError::JsonError { message } if message == "parse failed"
        ));

        let err = EmbeddingError::network_error("connection reset");
        assert!(matches!(
            err,
            EmbeddingError::NetworkError { message } if message == "connection reset"
        ));

        let err = EmbeddingError::invalid_request("bad parameter");
        assert!(matches!(
            err,
            EmbeddingError::InvalidRequest { message } if message == "bad parameter"
        ));

        let err = EmbeddingError::dimension_mismatch(1536, 768);
        assert!(matches!(
            err,
            EmbeddingError::DimensionMismatch {
                expected: 1536,
                actual: 768
            }
        ));
    }

    /// Timeout, rate limit and service-unavailable are retryable; the
    /// sampled other variants are not.
    #[test]
    fn test_embedding_error_is_retryable() {
        assert!(EmbeddingError::timeout().is_retryable());
        assert!(EmbeddingError::rate_limit(Some(60)).is_retryable());
        assert!(EmbeddingError::service_unavailable("down").is_retryable());

        assert!(!EmbeddingError::AuthenticationFailed.is_retryable());
        assert!(!EmbeddingError::EmptyInput.is_retryable());
        assert!(!EmbeddingError::json_error("parse failed").is_retryable());
    }

    /// A correctly sized embedding passes; any other length reports both the
    /// expected and the actual size.
    #[test]
    fn test_validate_dimensions() {
        let embedding = vec![0.1_f32; EMBEDDING_DIMENSIONS_COUNT];
        assert!(validate_dimensions(&embedding, EMBEDDING_DIMENSIONS_COUNT).is_ok());

        // Derive the wrong length from the constant so the case stays wrong
        // whatever value the constant has.
        let wrong_len = EMBEDDING_DIMENSIONS_COUNT + 1;
        let wrong_size = vec![0.1_f32; wrong_len];
        assert!(matches!(
            validate_dimensions(&wrong_size, EMBEDDING_DIMENSIONS_COUNT),
            Err(EmbeddingError::DimensionMismatch { expected, actual })
                if expected == EMBEDDING_DIMENSIONS_COUNT && actual == wrong_len
        ));
    }

    #[test]
    fn test_normalize_vector() {
        let vec = [3.0_f32, 4.0]; // Length = 5
        let normalized = normalize_vector(&vec);

        // Should be [0.6, 0.8]
        assert!((normalized[0] - 0.6).abs() < 0.001);
        assert!((normalized[1] - 0.8).abs() < 0.001);

        // Verify unit length
        assert!(is_normalized(&normalized, 0.001));
    }

    #[test]
    fn test_is_normalized() {
        let unit = [1.0_f32, 0.0, 0.0];
        assert!(is_normalized(&unit, 0.001));

        let not_unit = [2.0_f32, 0.0, 0.0];
        assert!(!is_normalized(&not_unit, 0.001));

        let normalized = [0.6_f32, 0.8]; // 3-4-5 triangle
        assert!(is_normalized(&normalized, 0.001));
    }

    #[test]
    #[should_panic(expected = "Cannot normalize zero vector")]
    fn test_normalize_zero_vector() {
        let zero = [0.0_f32, 0.0, 0.0];
        let _ = normalize_vector(&zero);
    }
}