// sochdb_query/embedding_provider.rs
// SPDX-License-Identifier: AGPL-3.0-or-later
// SochDB - LLM-Optimized Embedded Database
// Copyright (C) 2026 Sushanth Reddy Vanagala (https://github.com/sushanthpy)
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! Automatic Embedding Generation (Task 2)
//!
//! This module provides colocated embedding resolution for text-to-vector conversion.
//! It enables first-class text search by automatically generating embeddings.
//!
//! ## Design
//!
//! ```text
//! search_text(collection, text, k)
//!     │
//!     ▼
//! ┌─────────────────┐
//! │ EmbeddingProvider │
//! │  ├─ LRU Cache    │
//! │  └─ ONNX Runtime │
//! └─────────────────┘
//!     │
//!     ▼
//! search_by_embedding(collection, embedding, k)
//! ```
//!
//! ## Providers
//!
//! - `LocalProvider`: Uses FastEmbed/ONNX for offline embedding
//! - `CachedProvider`: LRU cache wrapper for any provider
//! - `MockProvider`: For testing
//!
//! ## Complexity
//!
//! - Embedding generation: O(n) where n = text length (transformer inference)
//! - Cache lookup: O(1) expected (hash-based LRU)
//! - Batch embedding: O(k) compute with ~O(1) ONNX session overhead
use std::sync::Arc;

use moka::sync::Cache;

// ============================================================================
// Embedding Provider Trait
// ============================================================================

/// Error type for embedding operations.
#[derive(Debug, Clone)]
pub enum EmbeddingError {
    /// Model not loaded or unavailable; the payload is the model name.
    ModelNotAvailable(String),
    /// Input text exceeds the provider's `max_length` (compared in bytes
    /// by `MockEmbeddingProvider::embed`).
    TextTooLong { max_length: usize, actual: usize },
    /// Supplied/produced vector length differs from the expected dimension.
    DimensionMismatch { expected: usize, actual: usize },
    /// Generic provider-side failure (inference, model loading, ...).
    ProviderError(String),
    /// Failure in the caching layer.
    CacheError(String),
}
73impl std::fmt::Display for EmbeddingError {
74    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75        match self {
76            Self::ModelNotAvailable(model) => write!(f, "Embedding model not available: {}", model),
77            Self::TextTooLong { max_length, actual } => {
78                write!(f, "Text too long: {} > {} max", actual, max_length)
79            }
80            Self::DimensionMismatch { expected, actual } => {
81                write!(f, "Dimension mismatch: expected {}, got {}", expected, actual)
82            }
83            Self::ProviderError(msg) => write!(f, "Provider error: {}", msg),
84            Self::CacheError(msg) => write!(f, "Cache error: {}", msg),
85        }
86    }
87}
88
// Marker impl: the `Debug` + `Display` impls above satisfy `Error`'s bounds.
impl std::error::Error for EmbeddingError {}

/// Result type for embedding operations.
pub type EmbeddingResult<T> = Result<T, EmbeddingError>;
94/// Embedding provider trait
95pub trait EmbeddingProvider: Send + Sync {
96    /// Get the model name
97    fn model_name(&self) -> &str;
98    
99    /// Get the embedding dimension
100    fn dimension(&self) -> usize;
101    
102    /// Maximum text length (in characters or tokens)
103    fn max_length(&self) -> usize;
104    
105    /// Generate embedding for a single text
106    fn embed(&self, text: &str) -> EmbeddingResult<Vec<f32>>;
107    
108    /// Generate embeddings for multiple texts (batch)
109    fn embed_batch(&self, texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
110        // Default implementation: sequential embedding
111        texts.iter().map(|t| self.embed(t)).collect()
112    }
113    
114    /// Normalize an embedding vector (L2 normalization)
115    fn normalize(&self, embedding: &mut [f32]) {
116        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
117        if norm > 1e-10 {
118            for x in embedding.iter_mut() {
119                *x /= norm;
120            }
121        }
122    }
123}
124
// ============================================================================
// Embedding Configuration
// ============================================================================

/// Configuration shared by embedding providers.
#[derive(Debug, Clone)]
pub struct EmbeddingConfig {
    /// Model identifier (e.g., "all-MiniLM-L6-v2")
    pub model: String,

    /// Filesystem path to a local ONNX model, if any.
    pub model_path: Option<String>,

    /// Embedding dimension produced by the model.
    pub dimension: usize,

    /// Maximum text length accepted by the model.
    pub max_length: usize,

    /// Whether to L2-normalize embeddings after generation.
    pub normalize: bool,

    /// Batch size for embedding generation.
    pub batch_size: usize,

    /// Cache capacity (number of embeddings to cache).
    pub cache_size: usize,

    /// Cache TTL in seconds (0 = no expiry).
    /// NOTE(review): this field is not read by the cache constructors in this
    /// module — `CachedEmbeddingProvider::with_ttl` takes the TTL explicitly.
    pub cache_ttl_secs: u64,
}
impl Default for EmbeddingConfig {
    /// Defaults tuned for the `all-MiniLM-L6-v2` sentence-transformer.
    fn default() -> Self {
        Self {
            model: "all-MiniLM-L6-v2".to_string(),
            model_path: None,
            dimension: 384, // MiniLM output dimension
            max_length: 512,
            normalize: true,
            batch_size: 32,
            cache_size: 10_000,
            cache_ttl_secs: 3600, // 1 hour
        }
    }
}
172impl EmbeddingConfig {
173    /// Create config for sentence-transformers models
174    pub fn sentence_transformer(model: &str) -> Self {
175        let dimension = match model {
176            "all-MiniLM-L6-v2" => 384,
177            "all-MiniLM-L12-v2" => 384,
178            "all-mpnet-base-v2" => 768,
179            "paraphrase-MiniLM-L6-v2" => 384,
180            "multi-qa-MiniLM-L6-cos-v1" => 384,
181            _ => 384, // Default
182        };
183        
184        Self {
185            model: model.to_string(),
186            dimension,
187            ..Default::default()
188        }
189    }
190    
191    /// Create config for OpenAI-compatible models
192    pub fn openai(model: &str) -> Self {
193        let dimension = match model {
194            "text-embedding-ada-002" => 1536,
195            "text-embedding-3-small" => 1536,
196            "text-embedding-3-large" => 3072,
197            _ => 1536,
198        };
199        
200        Self {
201            model: model.to_string(),
202            dimension,
203            max_length: 8192,
204            ..Default::default()
205        }
206    }
207}

// ============================================================================
// Mock Embedding Provider (for testing)
// ============================================================================

/// Mock embedding provider for testing.
///
/// Produces deterministic pseudo-random vectors derived from the input text,
/// so tests get stable, repeatable embeddings without a real model.
pub struct MockEmbeddingProvider {
    // Provider configuration (dimension, max_length, normalize, ...).
    config: EmbeddingConfig,
    // When true, embeddings are derived from the text hash; when false,
    // `embed` returns an all-zero vector.
    use_hash: bool,
}
220impl MockEmbeddingProvider {
221    /// Create a new mock provider
222    pub fn new(dimension: usize) -> Self {
223        Self {
224            config: EmbeddingConfig {
225                model: "mock".to_string(),
226                dimension,
227                ..Default::default()
228            },
229            use_hash: true,
230        }
231    }
232    
233    /// Create with custom config
234    pub fn with_config(config: EmbeddingConfig) -> Self {
235        Self {
236            config,
237            use_hash: true,
238        }
239    }
240    
241    /// Generate a deterministic embedding from text
242    fn hash_embed(&self, text: &str) -> Vec<f32> {
243        use std::hash::{Hash, Hasher};
244        use std::collections::hash_map::DefaultHasher;
245        
246        let mut embedding = Vec::with_capacity(self.config.dimension);
247        
248        // Generate pseudo-random values based on text hash
249        for i in 0..self.config.dimension {
250            let mut hasher = DefaultHasher::new();
251            text.hash(&mut hasher);
252            i.hash(&mut hasher);
253            let hash = hasher.finish();
254            
255            // Convert to f32 in range [-1, 1]
256            let value = ((hash as f64) / (u64::MAX as f64) * 2.0 - 1.0) as f32;
257            embedding.push(value);
258        }
259        
260        embedding
261    }
262}
263
264impl EmbeddingProvider for MockEmbeddingProvider {
265    fn model_name(&self) -> &str {
266        &self.config.model
267    }
268    
269    fn dimension(&self) -> usize {
270        self.config.dimension
271    }
272    
273    fn max_length(&self) -> usize {
274        self.config.max_length
275    }
276    
277    fn embed(&self, text: &str) -> EmbeddingResult<Vec<f32>> {
278        if text.len() > self.config.max_length {
279            return Err(EmbeddingError::TextTooLong {
280                max_length: self.config.max_length,
281                actual: text.len(),
282            });
283        }
284        
285        let mut embedding = if self.use_hash {
286            self.hash_embed(text)
287        } else {
288            vec![0.0; self.config.dimension]
289        };
290        
291        if self.config.normalize {
292            self.normalize(&mut embedding);
293        }
294        
295        Ok(embedding)
296    }
297}

// ============================================================================
// Cached Embedding Provider
// ============================================================================

/// LRU-cached wrapper around any [`EmbeddingProvider`].
pub struct CachedEmbeddingProvider<P: EmbeddingProvider> {
    /// Inner provider that computes embeddings on cache misses.
    inner: P,

    /// Cache keyed by a 64-bit hash of the input text.
    /// NOTE(review): two distinct texts with colliding hashes would share one
    /// entry — acceptable for a cache, but worth knowing.
    cache: Cache<u64, Vec<f32>>,

    /// Shared hit/miss/size counters (see [`CacheStats`]).
    stats: Arc<CacheStats>,
}
/// Cache statistics. All counters use relaxed atomics, so readings are
/// approximate under concurrency — fine for metrics.
#[derive(Debug, Default)]
pub struct CacheStats {
    /// Number of cache hits
    pub hits: std::sync::atomic::AtomicUsize,
    /// Number of cache misses
    pub misses: std::sync::atomic::AtomicUsize,
    /// Number of embeddings inserted into the cache. Only ever incremented —
    /// evictions and TTL expiry are not subtracted, so this is a lifetime
    /// insert count, not the current cache size.
    pub size: std::sync::atomic::AtomicUsize,
}
326impl CacheStats {
327    /// Get hit rate
328    pub fn hit_rate(&self) -> f64 {
329        let hits = self.hits.load(std::sync::atomic::Ordering::Relaxed);
330        let misses = self.misses.load(std::sync::atomic::Ordering::Relaxed);
331        let total = hits + misses;
332        if total == 0 {
333            0.0
334        } else {
335            hits as f64 / total as f64
336        }
337    }
338}
339
340impl<P: EmbeddingProvider> CachedEmbeddingProvider<P> {
341    /// Create a new cached provider
342    pub fn new(inner: P, cache_size: usize) -> Self {
343        Self {
344            inner,
345            cache: Cache::new(cache_size as u64),
346            stats: Arc::new(CacheStats::default()),
347        }
348    }
349    
350    /// Create with TTL
351    pub fn with_ttl(inner: P, cache_size: usize, ttl_secs: u64) -> Self {
352        let cache = Cache::builder()
353            .max_capacity(cache_size as u64)
354            .time_to_live(std::time::Duration::from_secs(ttl_secs))
355            .build();
356        
357        Self {
358            inner,
359            cache,
360            stats: Arc::new(CacheStats::default()),
361        }
362    }
363    
364    /// Get cache statistics
365    pub fn stats(&self) -> &Arc<CacheStats> {
366        &self.stats
367    }
368    
369    /// Compute hash for cache key
370    fn text_hash(text: &str) -> u64 {
371        use std::hash::{Hash, Hasher};
372        use std::collections::hash_map::DefaultHasher;
373        
374        let mut hasher = DefaultHasher::new();
375        text.hash(&mut hasher);
376        hasher.finish()
377    }
378}
379
380impl<P: EmbeddingProvider> EmbeddingProvider for CachedEmbeddingProvider<P> {
381    fn model_name(&self) -> &str {
382        self.inner.model_name()
383    }
384    
385    fn dimension(&self) -> usize {
386        self.inner.dimension()
387    }
388    
389    fn max_length(&self) -> usize {
390        self.inner.max_length()
391    }
392    
393    fn embed(&self, text: &str) -> EmbeddingResult<Vec<f32>> {
394        let hash = Self::text_hash(text);
395        
396        // Check cache
397        if let Some(cached) = self.cache.get(&hash) {
398            self.stats.hits.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
399            return Ok(cached);
400        }
401        
402        self.stats.misses.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
403        
404        // Generate embedding
405        let embedding = self.inner.embed(text)?;
406        
407        // Cache result
408        self.cache.insert(hash, embedding.clone());
409        self.stats.size.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
410        
411        Ok(embedding)
412    }
413    
414    fn embed_batch(&self, texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
415        let mut results = Vec::with_capacity(texts.len());
416        let mut uncached: Vec<(usize, &str)> = Vec::new();
417        
418        // Check cache for each text
419        for (i, text) in texts.iter().enumerate() {
420            let hash = Self::text_hash(text);
421            if let Some(cached) = self.cache.get(&hash) {
422                self.stats.hits.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
423                results.push((i, cached));
424            } else {
425                self.stats.misses.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
426                uncached.push((i, *text));
427            }
428        }
429        
430        // Generate embeddings for uncached texts
431        if !uncached.is_empty() {
432            let uncached_texts: Vec<&str> = uncached.iter().map(|(_, t)| *t).collect();
433            let embeddings = self.inner.embed_batch(&uncached_texts)?;
434            
435            for ((i, text), embedding) in uncached.iter().zip(embeddings.into_iter()) {
436                let hash = Self::text_hash(text);
437                self.cache.insert(hash, embedding.clone());
438                self.stats.size.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
439                results.push((*i, embedding));
440            }
441        }
442        
443        // Sort by original index
444        results.sort_by_key(|(i, _)| *i);
445        Ok(results.into_iter().map(|(_, e)| e).collect())
446    }
447}

// ============================================================================
// Local ONNX Provider (Stub)
// ============================================================================

/// Local ONNX-based embedding provider
///
/// This is a stub implementation. In production, this would use:
/// - ort (ONNX Runtime) for model inference
/// - fastembed-rs for pre-packaged models
/// - tokenizers for text preprocessing
#[derive(Debug)]
pub struct LocalOnnxProvider {
    // Configuration (model name, dimension, max_length, ...).
    config: EmbeddingConfig,
    /// Model weights (placeholder) — always `false` in this stub.
    #[allow(dead_code)]
    model_loaded: bool,
}
467impl LocalOnnxProvider {
468    /// Create a new local ONNX provider
469    pub fn new(config: EmbeddingConfig) -> EmbeddingResult<Self> {
470        // In production: Load ONNX model from path
471        Ok(Self {
472            config,
473            model_loaded: false,
474        })
475    }
476    
477    /// Load a pre-trained model by name
478    pub fn load_pretrained(model_name: &str) -> EmbeddingResult<Self> {
479        let config = EmbeddingConfig::sentence_transformer(model_name);
480        Self::new(config)
481    }
482}
483
484impl EmbeddingProvider for LocalOnnxProvider {
485    fn model_name(&self) -> &str {
486        &self.config.model
487    }
488    
489    fn dimension(&self) -> usize {
490        self.config.dimension
491    }
492    
493    fn max_length(&self) -> usize {
494        self.config.max_length
495    }
496    
497    fn embed(&self, text: &str) -> EmbeddingResult<Vec<f32>> {
498        // Stub: Return mock embedding
499        // In production: Run ONNX inference
500        let mock = MockEmbeddingProvider::with_config(self.config.clone());
501        mock.embed(text)
502    }
503}

// ============================================================================
// Embedding-Enabled Vector Index
// ============================================================================

/// Vector index wrapper that adds automatic text embedding on top of an
/// existing [`crate::context_query::VectorIndex`].
pub struct EmbeddingVectorIndex<V, P>
where
    V: crate::context_query::VectorIndex,
    P: EmbeddingProvider,
{
    /// Underlying vector index that performs the similarity search.
    index: Arc<V>,

    /// Provider used to turn query text into embeddings.
    provider: Arc<P>,
}
522impl<V, P> EmbeddingVectorIndex<V, P>
523where
524    V: crate::context_query::VectorIndex,
525    P: EmbeddingProvider,
526{
527    /// Create a new embedding-enabled vector index
528    pub fn new(index: Arc<V>, provider: Arc<P>) -> Self {
529        Self { index, provider }
530    }
531    
532    /// Search by text (automatically generates embedding)
533    pub fn search_text(
534        &self,
535        collection: &str,
536        text: &str,
537        k: usize,
538        min_score: Option<f32>,
539    ) -> Result<Vec<crate::context_query::VectorSearchResult>, String> {
540        // Generate embedding
541        let embedding = self.provider.embed(text)
542            .map_err(|e| e.to_string())?;
543        
544        // Search by embedding
545        self.index.search_by_embedding(collection, &embedding, k, min_score)
546    }
547    
548    /// Search by embedding (pass-through)
549    pub fn search_embedding(
550        &self,
551        collection: &str,
552        embedding: &[f32],
553        k: usize,
554        min_score: Option<f32>,
555    ) -> Result<Vec<crate::context_query::VectorSearchResult>, String> {
556        // Validate dimension
557        if embedding.len() != self.provider.dimension() {
558            return Err(format!(
559                "Embedding dimension mismatch: expected {}, got {}",
560                self.provider.dimension(),
561                embedding.len()
562            ));
563        }
564        
565        self.index.search_by_embedding(collection, embedding, k, min_score)
566    }
567    
568    /// Get the embedding provider
569    pub fn provider(&self) -> &Arc<P> {
570        &self.provider
571    }
572    
573    /// Get the underlying index
574    pub fn index(&self) -> &Arc<V> {
575        &self.index
576    }
577}
578
// Implementing the index trait lets an `EmbeddingVectorIndex` be used
// anywhere a plain `VectorIndex` is expected; every method delegates to the
// wrapper's inherent methods or the wrapped index.
impl<V, P> crate::context_query::VectorIndex for EmbeddingVectorIndex<V, P>
where
    V: crate::context_query::VectorIndex,
    P: EmbeddingProvider,
{
    // Delegates to `search_embedding` (which validates the dimension).
    fn search_by_embedding(
        &self,
        collection: &str,
        embedding: &[f32],
        k: usize,
        min_score: Option<f32>,
    ) -> Result<Vec<crate::context_query::VectorSearchResult>, String> {
        self.search_embedding(collection, embedding, k, min_score)
    }

    // Delegates to `search_text` (which embeds the query first).
    fn search_by_text(
        &self,
        collection: &str,
        text: &str,
        k: usize,
        min_score: Option<f32>,
    ) -> Result<Vec<crate::context_query::VectorSearchResult>, String> {
        self.search_text(collection, text, k, min_score)
    }

    // Stats come straight from the wrapped index.
    fn stats(&self, collection: &str) -> Option<crate::context_query::VectorIndexStats> {
        self.index.stats(collection)
    }
}

// ============================================================================
// Convenience Functions
// ============================================================================

613/// Create a cached mock embedding provider for testing
614pub fn create_mock_provider(dimension: usize, cache_size: usize) -> CachedEmbeddingProvider<MockEmbeddingProvider> {
615    let mock = MockEmbeddingProvider::new(dimension);
616    CachedEmbeddingProvider::new(mock, cache_size)
617}
618
619/// Create an embedding-enabled vector index with mock provider
620pub fn create_embedding_index<V: crate::context_query::VectorIndex>(
621    index: Arc<V>,
622    dimension: usize,
623) -> EmbeddingVectorIndex<V, CachedEmbeddingProvider<MockEmbeddingProvider>> {
624    let provider = Arc::new(create_mock_provider(dimension, 10_000));
625    EmbeddingVectorIndex::new(index, provider)
626}

// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::Ordering;

    /// Same text must always map to the same vector of the right length.
    #[test]
    fn test_mock_embedding_deterministic() {
        let provider = MockEmbeddingProvider::new(384);

        let first = provider.embed("hello world").unwrap();
        let second = provider.embed("hello world").unwrap();

        assert_eq!(first, second);
        assert_eq!(first.len(), 384);
    }

    /// Distinct texts should (with overwhelming probability) differ.
    #[test]
    fn test_mock_embedding_different_texts() {
        let provider = MockEmbeddingProvider::new(384);

        assert_ne!(
            provider.embed("hello").unwrap(),
            provider.embed("world").unwrap()
        );
    }

    /// First lookup misses, second hits; hit rate ends up at 1/2.
    #[test]
    fn test_cached_provider() {
        let cached = CachedEmbeddingProvider::new(MockEmbeddingProvider::new(128), 100);

        let _ = cached.embed("test text").unwrap();
        assert_eq!(cached.stats().hits.load(Ordering::Relaxed), 0);
        assert_eq!(cached.stats().misses.load(Ordering::Relaxed), 1);

        let _ = cached.embed("test text").unwrap();
        assert_eq!(cached.stats().hits.load(Ordering::Relaxed), 1);
        assert_eq!(cached.stats().misses.load(Ordering::Relaxed), 1);

        assert!(cached.stats().hit_rate() > 0.4);
    }

    /// Batch embedding returns one vector per input, each full-size.
    #[test]
    fn test_batch_embedding() {
        let cached = CachedEmbeddingProvider::new(MockEmbeddingProvider::new(128), 100);

        let embeddings = cached.embed_batch(&["hello", "world", "test"]).unwrap();

        assert_eq!(embeddings.len(), 3);
        assert!(embeddings.iter().all(|e| e.len() == 128));
    }

    /// The default config normalizes, so the L2 norm must be ~1.
    #[test]
    fn test_normalization() {
        let embedding = MockEmbeddingProvider::new(3).embed("test").unwrap();

        let norm = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5);
    }

    /// Texts over `max_length` are rejected with `TextTooLong`.
    #[test]
    fn test_text_too_long() {
        let provider = MockEmbeddingProvider::with_config(EmbeddingConfig {
            max_length: 10,
            ..Default::default()
        });

        let outcome = provider.embed("this is a very long text that exceeds the limit");
        assert!(matches!(outcome, Err(EmbeddingError::TextTooLong { .. })));
    }
}