// sochdb_query/embedding_provider.rs

1// Copyright 2025 Sushanth (https://github.com/sushanthpy)
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Automatic Embedding Generation (Task 2)
16//!
17//! This module provides colocated embedding resolution for text-to-vector conversion.
18//! It enables first-class text search by automatically generating embeddings.
19//!
20//! ## Design
21//!
22//! ```text
23//! search_text(collection, text, k)
24//!     │
25//!     ▼
26//! ┌─────────────────┐
27//! │ EmbeddingProvider │
28//! │  ├─ LRU Cache    │
29//! │  └─ ONNX Runtime │
30//! └─────────────────┘
31//!     │
32//!     ▼
33//! search_by_embedding(collection, embedding, k)
34//! ```
35//!
36//! ## Providers
37//!
38//! - `LocalProvider`: Uses FastEmbed/ONNX for offline embedding
39//! - `CachedProvider`: LRU cache wrapper for any provider
40//! - `MockProvider`: For testing
41//!
42//! ## Complexity
43//!
44//! - Embedding generation: O(n) where n = text length (transformer inference)
45//! - Cache lookup: O(1) expected (hash-based LRU)
46//! - Batch embedding: O(k) compute with ~O(1) ONNX session overhead
47
48use std::sync::Arc;
49use moka::sync::Cache;
50
51// ============================================================================
52// Embedding Provider Trait
53// ============================================================================
54
/// Error type for embedding operations.
#[derive(Debug, Clone)]
pub enum EmbeddingError {
    /// Model not loaded or unavailable
    ModelNotAvailable(String),
    /// Text too long for model
    TextTooLong { max_length: usize, actual: usize },
    /// Dimension mismatch
    DimensionMismatch { expected: usize, actual: usize },
    /// Provider error
    ProviderError(String),
    /// Cache error
    CacheError(String),
}

impl std::fmt::Display for EmbeddingError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ModelNotAvailable(model) => {
                write!(f, "Embedding model not available: {model}")
            }
            Self::TextTooLong { max_length, actual } => {
                write!(f, "Text too long: {actual} > {max_length} max")
            }
            Self::DimensionMismatch { expected, actual } => {
                write!(f, "Dimension mismatch: expected {expected}, got {actual}")
            }
            Self::ProviderError(msg) => write!(f, "Provider error: {msg}"),
            Self::CacheError(msg) => write!(f, "Cache error: {msg}"),
        }
    }
}

impl std::error::Error for EmbeddingError {}

/// Result type for embedding operations
pub type EmbeddingResult<T> = Result<T, EmbeddingError>;
90
91/// Embedding provider trait
92pub trait EmbeddingProvider: Send + Sync {
93    /// Get the model name
94    fn model_name(&self) -> &str;
95    
96    /// Get the embedding dimension
97    fn dimension(&self) -> usize;
98    
99    /// Maximum text length (in characters or tokens)
100    fn max_length(&self) -> usize;
101    
102    /// Generate embedding for a single text
103    fn embed(&self, text: &str) -> EmbeddingResult<Vec<f32>>;
104    
105    /// Generate embeddings for multiple texts (batch)
106    fn embed_batch(&self, texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
107        // Default implementation: sequential embedding
108        texts.iter().map(|t| self.embed(t)).collect()
109    }
110    
111    /// Normalize an embedding vector (L2 normalization)
112    fn normalize(&self, embedding: &mut [f32]) {
113        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
114        if norm > 1e-10 {
115            for x in embedding.iter_mut() {
116                *x /= norm;
117            }
118        }
119    }
120}
121
122// ============================================================================
123// Embedding Configuration
124// ============================================================================
125
/// Configuration for embedding providers
#[derive(Debug, Clone)]
pub struct EmbeddingConfig {
    /// Model identifier (e.g., "all-MiniLM-L6-v2")
    pub model: String,
    
    /// Model path (for local ONNX models)
    pub model_path: Option<String>,
    
    /// Embedding dimension
    pub dimension: usize,
    
    /// Maximum text length
    pub max_length: usize,
    
    /// Whether to normalize embeddings
    pub normalize: bool,
    
    /// Batch size for embedding generation
    pub batch_size: usize,
    
    /// Cache size (number of embeddings to cache)
    pub cache_size: usize,
    
    /// Cache TTL in seconds (0 = no expiry)
    pub cache_ttl_secs: u64,
}

impl Default for EmbeddingConfig {
    /// Defaults tuned for the `all-MiniLM-L6-v2` sentence-transformer model.
    fn default() -> Self {
        Self {
            model: String::from("all-MiniLM-L6-v2"),
            model_path: None,
            dimension: 384, // MiniLM output width
            max_length: 512,
            normalize: true,
            batch_size: 32,
            cache_size: 10_000,
            cache_ttl_secs: 3600, // one hour
        }
    }
}

impl EmbeddingConfig {
    /// Create config for sentence-transformers models
    ///
    /// Unrecognized model names fall back to the MiniLM width of 384.
    pub fn sentence_transformer(model: &str) -> Self {
        let dimension = match model {
            "all-MiniLM-L6-v2"
            | "all-MiniLM-L12-v2"
            | "paraphrase-MiniLM-L6-v2"
            | "multi-qa-MiniLM-L6-cos-v1" => 384,
            "all-mpnet-base-v2" => 768,
            _ => 384, // default for unknown models
        };
        
        Self {
            model: model.to_string(),
            dimension,
            ..Self::default()
        }
    }
    
    /// Create config for OpenAI-compatible models
    ///
    /// Raises the text limit to 8192 to match OpenAI's larger context.
    pub fn openai(model: &str) -> Self {
        let dimension = match model {
            "text-embedding-3-large" => 3072,
            // "text-embedding-ada-002", "text-embedding-3-small",
            // and any unknown model all use 1536 dimensions.
            _ => 1536,
        };
        
        Self {
            model: model.to_string(),
            dimension,
            max_length: 8192,
            ..Self::default()
        }
    }
}
205
206// ============================================================================
207// Mock Embedding Provider (for testing)
208// ============================================================================
209
210/// Mock embedding provider for testing
211pub struct MockEmbeddingProvider {
212    config: EmbeddingConfig,
213    /// Deterministic embeddings based on text hash
214    use_hash: bool,
215}
216
217impl MockEmbeddingProvider {
218    /// Create a new mock provider
219    pub fn new(dimension: usize) -> Self {
220        Self {
221            config: EmbeddingConfig {
222                model: "mock".to_string(),
223                dimension,
224                ..Default::default()
225            },
226            use_hash: true,
227        }
228    }
229    
230    /// Create with custom config
231    pub fn with_config(config: EmbeddingConfig) -> Self {
232        Self {
233            config,
234            use_hash: true,
235        }
236    }
237    
238    /// Generate a deterministic embedding from text
239    fn hash_embed(&self, text: &str) -> Vec<f32> {
240        use std::hash::{Hash, Hasher};
241        use std::collections::hash_map::DefaultHasher;
242        
243        let mut embedding = Vec::with_capacity(self.config.dimension);
244        
245        // Generate pseudo-random values based on text hash
246        for i in 0..self.config.dimension {
247            let mut hasher = DefaultHasher::new();
248            text.hash(&mut hasher);
249            i.hash(&mut hasher);
250            let hash = hasher.finish();
251            
252            // Convert to f32 in range [-1, 1]
253            let value = ((hash as f64) / (u64::MAX as f64) * 2.0 - 1.0) as f32;
254            embedding.push(value);
255        }
256        
257        embedding
258    }
259}
260
261impl EmbeddingProvider for MockEmbeddingProvider {
262    fn model_name(&self) -> &str {
263        &self.config.model
264    }
265    
266    fn dimension(&self) -> usize {
267        self.config.dimension
268    }
269    
270    fn max_length(&self) -> usize {
271        self.config.max_length
272    }
273    
274    fn embed(&self, text: &str) -> EmbeddingResult<Vec<f32>> {
275        if text.len() > self.config.max_length {
276            return Err(EmbeddingError::TextTooLong {
277                max_length: self.config.max_length,
278                actual: text.len(),
279            });
280        }
281        
282        let mut embedding = if self.use_hash {
283            self.hash_embed(text)
284        } else {
285            vec![0.0; self.config.dimension]
286        };
287        
288        if self.config.normalize {
289            self.normalize(&mut embedding);
290        }
291        
292        Ok(embedding)
293    }
294}
295
296// ============================================================================
297// Cached Embedding Provider
298// ============================================================================
299
/// LRU-cached embedding provider wrapper
///
/// Wraps any [`EmbeddingProvider`] and memoizes its `embed` results, keyed
/// by a 64-bit hash of the input text.
/// NOTE(review): two distinct texts that collide on the 64-bit hash would
/// share a cache entry — astronomically unlikely, but worth knowing.
pub struct CachedEmbeddingProvider<P: EmbeddingProvider> {
    /// Inner provider
    inner: P,
    
    /// LRU cache: text hash -> embedding
    cache: Cache<u64, Vec<f32>>,
    
    /// Cache statistics
    stats: Arc<CacheStats>,
}
311
/// Cache statistics
///
/// All counters are relaxed atomics: cheap to bump, approximate under
/// concurrent updates.
#[derive(Debug, Default)]
pub struct CacheStats {
    /// Number of cache hits
    pub hits: std::sync::atomic::AtomicUsize,
    /// Number of cache misses
    pub misses: std::sync::atomic::AtomicUsize,
    /// Number of embeddings cached
    pub size: std::sync::atomic::AtomicUsize,
}

impl CacheStats {
    /// Fraction of lookups served from cache: hits / (hits + misses).
    /// Returns 0.0 when no lookups have been recorded yet.
    pub fn hit_rate(&self) -> f64 {
        use std::sync::atomic::Ordering;
        
        let hits = self.hits.load(Ordering::Relaxed);
        let misses = self.misses.load(Ordering::Relaxed);
        match hits + misses {
            0 => 0.0,
            total => hits as f64 / total as f64,
        }
    }
}
336
337impl<P: EmbeddingProvider> CachedEmbeddingProvider<P> {
338    /// Create a new cached provider
339    pub fn new(inner: P, cache_size: usize) -> Self {
340        Self {
341            inner,
342            cache: Cache::new(cache_size as u64),
343            stats: Arc::new(CacheStats::default()),
344        }
345    }
346    
347    /// Create with TTL
348    pub fn with_ttl(inner: P, cache_size: usize, ttl_secs: u64) -> Self {
349        let cache = Cache::builder()
350            .max_capacity(cache_size as u64)
351            .time_to_live(std::time::Duration::from_secs(ttl_secs))
352            .build();
353        
354        Self {
355            inner,
356            cache,
357            stats: Arc::new(CacheStats::default()),
358        }
359    }
360    
361    /// Get cache statistics
362    pub fn stats(&self) -> &Arc<CacheStats> {
363        &self.stats
364    }
365    
366    /// Compute hash for cache key
367    fn text_hash(text: &str) -> u64 {
368        use std::hash::{Hash, Hasher};
369        use std::collections::hash_map::DefaultHasher;
370        
371        let mut hasher = DefaultHasher::new();
372        text.hash(&mut hasher);
373        hasher.finish()
374    }
375}
376
377impl<P: EmbeddingProvider> EmbeddingProvider for CachedEmbeddingProvider<P> {
378    fn model_name(&self) -> &str {
379        self.inner.model_name()
380    }
381    
382    fn dimension(&self) -> usize {
383        self.inner.dimension()
384    }
385    
386    fn max_length(&self) -> usize {
387        self.inner.max_length()
388    }
389    
390    fn embed(&self, text: &str) -> EmbeddingResult<Vec<f32>> {
391        let hash = Self::text_hash(text);
392        
393        // Check cache
394        if let Some(cached) = self.cache.get(&hash) {
395            self.stats.hits.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
396            return Ok(cached);
397        }
398        
399        self.stats.misses.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
400        
401        // Generate embedding
402        let embedding = self.inner.embed(text)?;
403        
404        // Cache result
405        self.cache.insert(hash, embedding.clone());
406        self.stats.size.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
407        
408        Ok(embedding)
409    }
410    
411    fn embed_batch(&self, texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
412        let mut results = Vec::with_capacity(texts.len());
413        let mut uncached: Vec<(usize, &str)> = Vec::new();
414        
415        // Check cache for each text
416        for (i, text) in texts.iter().enumerate() {
417            let hash = Self::text_hash(text);
418            if let Some(cached) = self.cache.get(&hash) {
419                self.stats.hits.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
420                results.push((i, cached));
421            } else {
422                self.stats.misses.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
423                uncached.push((i, *text));
424            }
425        }
426        
427        // Generate embeddings for uncached texts
428        if !uncached.is_empty() {
429            let uncached_texts: Vec<&str> = uncached.iter().map(|(_, t)| *t).collect();
430            let embeddings = self.inner.embed_batch(&uncached_texts)?;
431            
432            for ((i, text), embedding) in uncached.iter().zip(embeddings.into_iter()) {
433                let hash = Self::text_hash(text);
434                self.cache.insert(hash, embedding.clone());
435                self.stats.size.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
436                results.push((*i, embedding));
437            }
438        }
439        
440        // Sort by original index
441        results.sort_by_key(|(i, _)| *i);
442        Ok(results.into_iter().map(|(_, e)| e).collect())
443    }
444}
445
446// ============================================================================
447// Local ONNX Provider (Stub)
448// ============================================================================
449
/// Local ONNX-based embedding provider
/// 
/// This is a stub implementation. In production, this would use:
/// - ort (ONNX Runtime) for model inference
/// - fastembed-rs for pre-packaged models
/// - tokenizers for text preprocessing
///
/// Until then, `embed` delegates to [`MockEmbeddingProvider`] and no model
/// is ever actually loaded.
#[derive(Debug)]
pub struct LocalOnnxProvider {
    // Provider configuration (model name, dimension, limits).
    config: EmbeddingConfig,
    /// Model weights (placeholder)
    /// Always `false` in this stub.
    #[allow(dead_code)]
    model_loaded: bool,
}
463
464impl LocalOnnxProvider {
465    /// Create a new local ONNX provider
466    pub fn new(config: EmbeddingConfig) -> EmbeddingResult<Self> {
467        // In production: Load ONNX model from path
468        Ok(Self {
469            config,
470            model_loaded: false,
471        })
472    }
473    
474    /// Load a pre-trained model by name
475    pub fn load_pretrained(model_name: &str) -> EmbeddingResult<Self> {
476        let config = EmbeddingConfig::sentence_transformer(model_name);
477        Self::new(config)
478    }
479}
480
481impl EmbeddingProvider for LocalOnnxProvider {
482    fn model_name(&self) -> &str {
483        &self.config.model
484    }
485    
486    fn dimension(&self) -> usize {
487        self.config.dimension
488    }
489    
490    fn max_length(&self) -> usize {
491        self.config.max_length
492    }
493    
494    fn embed(&self, text: &str) -> EmbeddingResult<Vec<f32>> {
495        // Stub: Return mock embedding
496        // In production: Run ONNX inference
497        let mock = MockEmbeddingProvider::with_config(self.config.clone());
498        mock.embed(text)
499    }
500}
501
502// ============================================================================
503// Embedding-Enabled Vector Index
504// ============================================================================
505
/// Vector index with automatic text embedding
///
/// Decorates any `crate::context_query::VectorIndex` with an
/// [`EmbeddingProvider`] so callers can search by raw text; the text is
/// embedded first and the search is forwarded to the wrapped index.
pub struct EmbeddingVectorIndex<V, P> 
where
    V: crate::context_query::VectorIndex,
    P: EmbeddingProvider,
{
    /// Underlying vector index
    index: Arc<V>,
    
    /// Embedding provider
    provider: Arc<P>,
}
518
519impl<V, P> EmbeddingVectorIndex<V, P>
520where
521    V: crate::context_query::VectorIndex,
522    P: EmbeddingProvider,
523{
524    /// Create a new embedding-enabled vector index
525    pub fn new(index: Arc<V>, provider: Arc<P>) -> Self {
526        Self { index, provider }
527    }
528    
529    /// Search by text (automatically generates embedding)
530    pub fn search_text(
531        &self,
532        collection: &str,
533        text: &str,
534        k: usize,
535        min_score: Option<f32>,
536    ) -> Result<Vec<crate::context_query::VectorSearchResult>, String> {
537        // Generate embedding
538        let embedding = self.provider.embed(text)
539            .map_err(|e| e.to_string())?;
540        
541        // Search by embedding
542        self.index.search_by_embedding(collection, &embedding, k, min_score)
543    }
544    
545    /// Search by embedding (pass-through)
546    pub fn search_embedding(
547        &self,
548        collection: &str,
549        embedding: &[f32],
550        k: usize,
551        min_score: Option<f32>,
552    ) -> Result<Vec<crate::context_query::VectorSearchResult>, String> {
553        // Validate dimension
554        if embedding.len() != self.provider.dimension() {
555            return Err(format!(
556                "Embedding dimension mismatch: expected {}, got {}",
557                self.provider.dimension(),
558                embedding.len()
559            ));
560        }
561        
562        self.index.search_by_embedding(collection, embedding, k, min_score)
563    }
564    
565    /// Get the embedding provider
566    pub fn provider(&self) -> &Arc<P> {
567        &self.provider
568    }
569    
570    /// Get the underlying index
571    pub fn index(&self) -> &Arc<V> {
572        &self.index
573    }
574}
575
/// Makes `EmbeddingVectorIndex` itself usable wherever a
/// `crate::context_query::VectorIndex` is expected, by delegating each
/// trait method to the wrapper's inherent methods.
impl<V, P> crate::context_query::VectorIndex for EmbeddingVectorIndex<V, P>
where
    V: crate::context_query::VectorIndex,
    P: EmbeddingProvider,
{
    // Delegates to search_embedding, which dimension-validates the query
    // before forwarding to the wrapped index.
    fn search_by_embedding(
        &self,
        collection: &str,
        embedding: &[f32],
        k: usize,
        min_score: Option<f32>,
    ) -> Result<Vec<crate::context_query::VectorSearchResult>, String> {
        self.search_embedding(collection, embedding, k, min_score)
    }
    
    // Delegates to search_text, which embeds the text via the provider.
    fn search_by_text(
        &self,
        collection: &str,
        text: &str,
        k: usize,
        min_score: Option<f32>,
    ) -> Result<Vec<crate::context_query::VectorSearchResult>, String> {
        self.search_text(collection, text, k, min_score)
    }
    
    // Stats come straight from the wrapped index.
    fn stats(&self, collection: &str) -> Option<crate::context_query::VectorIndexStats> {
        self.index.stats(collection)
    }
}
605
606// ============================================================================
607// Convenience Functions
608// ============================================================================
609
610/// Create a cached mock embedding provider for testing
611pub fn create_mock_provider(dimension: usize, cache_size: usize) -> CachedEmbeddingProvider<MockEmbeddingProvider> {
612    let mock = MockEmbeddingProvider::new(dimension);
613    CachedEmbeddingProvider::new(mock, cache_size)
614}
615
616/// Create an embedding-enabled vector index with mock provider
617pub fn create_embedding_index<V: crate::context_query::VectorIndex>(
618    index: Arc<V>,
619    dimension: usize,
620) -> EmbeddingVectorIndex<V, CachedEmbeddingProvider<MockEmbeddingProvider>> {
621    let provider = Arc::new(create_mock_provider(dimension, 10_000));
622    EmbeddingVectorIndex::new(index, provider)
623}
624
625// ============================================================================
626// Tests
627// ============================================================================
628
#[cfg(test)]
mod tests {
    use super::*;
    
    /// The same text must always map to the same vector of the configured size.
    #[test]
    fn test_mock_embedding_deterministic() {
        let provider = MockEmbeddingProvider::new(384);
        
        let first = provider.embed("hello world").unwrap();
        let second = provider.embed("hello world").unwrap();
        
        assert_eq!(first, second);
        assert_eq!(first.len(), 384);
    }
    
    /// Different texts should map to different vectors.
    #[test]
    fn test_mock_embedding_different_texts() {
        let provider = MockEmbeddingProvider::new(384);
        
        assert_ne!(
            provider.embed("hello").unwrap(),
            provider.embed("world").unwrap()
        );
    }
    
    /// A repeated lookup must be served from the cache and counted as a hit.
    #[test]
    fn test_cached_provider() {
        use std::sync::atomic::Ordering;
        
        let cached = CachedEmbeddingProvider::new(MockEmbeddingProvider::new(128), 100);
        
        // First call populates the cache: one miss, no hits.
        cached.embed("test text").unwrap();
        assert_eq!(cached.stats().hits.load(Ordering::Relaxed), 0);
        assert_eq!(cached.stats().misses.load(Ordering::Relaxed), 1);
        
        // Second call is answered from the cache: one hit.
        cached.embed("test text").unwrap();
        assert_eq!(cached.stats().hits.load(Ordering::Relaxed), 1);
        assert_eq!(cached.stats().misses.load(Ordering::Relaxed), 1);
        
        assert!(cached.stats().hit_rate() > 0.4);
    }
    
    /// Batch embedding returns one vector per input, each of full width.
    #[test]
    fn test_batch_embedding() {
        let cached = CachedEmbeddingProvider::new(MockEmbeddingProvider::new(128), 100);
        
        let embeddings = cached.embed_batch(&["hello", "world", "test"]).unwrap();
        
        assert_eq!(embeddings.len(), 3);
        assert!(embeddings.iter().all(|e| e.len() == 128));
    }
    
    /// With `normalize: true` (the default) the output has unit L2 norm.
    #[test]
    fn test_normalization() {
        let embedding = MockEmbeddingProvider::new(3).embed("test").unwrap();
        
        let norm = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5);
    }
    
    /// Texts beyond `max_length` are rejected with `TextTooLong`.
    #[test]
    fn test_text_too_long() {
        let provider = MockEmbeddingProvider::with_config(EmbeddingConfig {
            max_length: 10,
            ..Default::default()
        });
        
        let result = provider.embed("this is a very long text that exceeds the limit");
        assert!(matches!(result, Err(EmbeddingError::TextTooLong { .. })));
    }
}