Skip to main content

sochdb_query/
embedding_provider.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// SochDB - LLM-Optimized Embedded Database
3// Copyright (C) 2026 Sushanth Reddy Vanagala (https://github.com/sushanthpy)
4//
5// This program is free software: you can redistribute it and/or modify
6// it under the terms of the GNU Affero General Public License as published by
7// the Free Software Foundation, either version 3 of the License, or
8// (at your option) any later version.
9//
10// This program is distributed in the hope that it will be useful,
11// but WITHOUT ANY WARRANTY; without even the implied warranty of
12// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13// GNU Affero General Public License for more details.
14//
15// You should have received a copy of the GNU Affero General Public License
16// along with this program. If not, see <https://www.gnu.org/licenses/>.
17
18//! Automatic Embedding Generation (Task 2)
19//!
20//! This module provides colocated embedding resolution for text-to-vector conversion.
21//! It enables first-class text search by automatically generating embeddings.
22//!
23//! ## Design
24//!
25//! ```text
26//! search_text(collection, text, k)
27//!     │
28//!     ▼
//! ┌───────────────────┐
//! │ EmbeddingProvider │
//! │  ├─ Cache (moka)  │
//! │  └─ ONNX Runtime  │
//! └───────────────────┘
34//!     │
35//!     ▼
36//! search_by_embedding(collection, embedding, k)
37//! ```
38//!
39//! ## Providers
40//!
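//! - `LocalOnnxProvider`: Intended to run FastEmbed/ONNX models offline
//!   (currently a stub that returns hash-based placeholder vectors)
//! - `CachedEmbeddingProvider`: Caching wrapper for any provider
//! - `MockEmbeddingProvider`: Deterministic hash-based provider for testing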
44//!
45//! ## Complexity
46//!
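//! - Embedding generation: dominated by model inference, whose cost grows with
//!   input length (the mock and stub providers derive vectors from a hash of the text)
//! - Cache lookup: O(1) expected (hash-keyed, size-bounded cache)
//! - Batch embedding: cached texts are served directly; only misses reach the
//!   inner provider, in a single `embed_batch` call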
50
51use moka::sync::Cache;
52use std::sync::Arc;
53
54// ============================================================================
55// Embedding Provider Trait
56// ============================================================================
57
/// Errors produced while generating embeddings.
///
/// Derives `PartialEq`/`Eq` so callers (and tests) can compare errors
/// directly instead of pattern-matching every field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EmbeddingError {
    /// Model not loaded or unavailable; carries the model name.
    ModelNotAvailable(String),
    /// Input text is longer than the provider's limit.
    TextTooLong { max_length: usize, actual: usize },
    /// An embedding did not have the expected number of components.
    DimensionMismatch { expected: usize, actual: usize },
    /// Generic failure reported by a provider.
    ProviderError(String),
    /// Failure inside the caching layer.
    CacheError(String),
}

impl std::fmt::Display for EmbeddingError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ModelNotAvailable(model) => write!(f, "Embedding model not available: {}", model),
            Self::TextTooLong { max_length, actual } => {
                write!(f, "Text too long: {} > {} max", actual, max_length)
            }
            Self::DimensionMismatch { expected, actual } => {
                write!(
                    f,
                    "Dimension mismatch: expected {}, got {}",
                    expected, actual
                )
            }
            Self::ProviderError(msg) => write!(f, "Provider error: {}", msg),
            Self::CacheError(msg) => write!(f, "Cache error: {}", msg),
        }
    }
}

impl std::error::Error for EmbeddingError {}

/// Result type for embedding operations.
pub type EmbeddingResult<T> = Result<T, EmbeddingError>;
97
98/// Embedding provider trait
99pub trait EmbeddingProvider: Send + Sync {
100    /// Get the model name
101    fn model_name(&self) -> &str;
102
103    /// Get the embedding dimension
104    fn dimension(&self) -> usize;
105
106    /// Maximum text length (in characters or tokens)
107    fn max_length(&self) -> usize;
108
109    /// Generate embedding for a single text
110    fn embed(&self, text: &str) -> EmbeddingResult<Vec<f32>>;
111
112    /// Generate embeddings for multiple texts (batch)
113    fn embed_batch(&self, texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
114        // Default implementation: sequential embedding
115        texts.iter().map(|t| self.embed(t)).collect()
116    }
117
118    /// Normalize an embedding vector (L2 normalization)
119    fn normalize(&self, embedding: &mut [f32]) {
120        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
121        if norm > 1e-10 {
122            for x in embedding.iter_mut() {
123                *x /= norm;
124            }
125        }
126    }
127}
128
129// ============================================================================
130// Embedding Configuration
131// ============================================================================
132
/// Settings shared by the embedding providers in this module.
#[derive(Debug, Clone)]
pub struct EmbeddingConfig {
    /// Model identifier, e.g. `"all-MiniLM-L6-v2"`.
    pub model: String,

    /// Filesystem path of a local (ONNX) model file, if any.
    pub model_path: Option<String>,

    /// Number of components in each embedding.
    pub dimension: usize,

    /// Maximum accepted input length.
    pub max_length: usize,

    /// Whether produced embeddings are L2-normalized.
    pub normalize: bool,

    /// Number of texts to embed per batch.
    pub batch_size: usize,

    /// Capacity of the embedding cache, in entries.
    pub cache_size: usize,

    /// Cache time-to-live in seconds; `0` means entries never expire.
    pub cache_ttl_secs: u64,
}

impl Default for EmbeddingConfig {
    /// Settings for `all-MiniLM-L6-v2`: 384 dimensions, 512 max length,
    /// normalized output, batches of 32, a 10,000-entry cache and a one-hour TTL.
    fn default() -> Self {
        EmbeddingConfig {
            model: String::from("all-MiniLM-L6-v2"),
            model_path: None,
            dimension: 384,
            max_length: 512,
            normalize: true,
            batch_size: 32,
            cache_size: 10_000,
            cache_ttl_secs: 60 * 60,
        }
    }
}

impl EmbeddingConfig {
    /// Config for a sentence-transformers model.
    ///
    /// The dimension is looked up from a small table of known models;
    /// unrecognised names fall back to 384. All other settings are the defaults.
    pub fn sentence_transformer(model: &str) -> Self {
        let dimension = match model {
            "all-mpnet-base-v2" => 768,
            "all-MiniLM-L6-v2"
            | "all-MiniLM-L12-v2"
            | "paraphrase-MiniLM-L6-v2"
            | "multi-qa-MiniLM-L6-cos-v1" => 384,
            _ => 384,
        };

        Self {
            dimension,
            model: model.to_owned(),
            ..Self::default()
        }
    }

    /// Config for an OpenAI-compatible embedding model.
    ///
    /// `text-embedding-3-large` maps to 3072 dimensions; every other name
    /// (including `text-embedding-ada-002` and `text-embedding-3-small`) maps
    /// to 1536. `max_length` is raised to 8192; the rest are defaults.
    pub fn openai(model: &str) -> Self {
        let dimension = match model {
            "text-embedding-3-large" => 3072,
            "text-embedding-ada-002" | "text-embedding-3-small" => 1536,
            _ => 1536,
        };

        Self {
            max_length: 8192,
            dimension,
            model: model.to_owned(),
            ..Self::default()
        }
    }
}
212
213// ============================================================================
214// Mock Embedding Provider (for testing)
215// ============================================================================
216
217/// Mock embedding provider for testing
218pub struct MockEmbeddingProvider {
219    config: EmbeddingConfig,
220    /// Deterministic embeddings based on text hash
221    use_hash: bool,
222}
223
224impl MockEmbeddingProvider {
225    /// Create a new mock provider
226    pub fn new(dimension: usize) -> Self {
227        Self {
228            config: EmbeddingConfig {
229                model: "mock".to_string(),
230                dimension,
231                ..Default::default()
232            },
233            use_hash: true,
234        }
235    }
236
237    /// Create with custom config
238    pub fn with_config(config: EmbeddingConfig) -> Self {
239        Self {
240            config,
241            use_hash: true,
242        }
243    }
244
245    /// Generate a deterministic embedding from text
246    fn hash_embed(&self, text: &str) -> Vec<f32> {
247        use std::collections::hash_map::DefaultHasher;
248        use std::hash::{Hash, Hasher};
249
250        let mut embedding = Vec::with_capacity(self.config.dimension);
251
252        // Generate pseudo-random values based on text hash
253        for i in 0..self.config.dimension {
254            let mut hasher = DefaultHasher::new();
255            text.hash(&mut hasher);
256            i.hash(&mut hasher);
257            let hash = hasher.finish();
258
259            // Convert to f32 in range [-1, 1]
260            let value = ((hash as f64) / (u64::MAX as f64) * 2.0 - 1.0) as f32;
261            embedding.push(value);
262        }
263
264        embedding
265    }
266}
267
268impl EmbeddingProvider for MockEmbeddingProvider {
269    fn model_name(&self) -> &str {
270        &self.config.model
271    }
272
273    fn dimension(&self) -> usize {
274        self.config.dimension
275    }
276
277    fn max_length(&self) -> usize {
278        self.config.max_length
279    }
280
281    fn embed(&self, text: &str) -> EmbeddingResult<Vec<f32>> {
282        if text.len() > self.config.max_length {
283            return Err(EmbeddingError::TextTooLong {
284                max_length: self.config.max_length,
285                actual: text.len(),
286            });
287        }
288
289        let mut embedding = if self.use_hash {
290            self.hash_embed(text)
291        } else {
292            vec![0.0; self.config.dimension]
293        };
294
295        if self.config.normalize {
296            self.normalize(&mut embedding);
297        }
298
299        Ok(embedding)
300    }
301}
302
303// ============================================================================
304// Cached Embedding Provider
305// ============================================================================
306
307/// LRU-cached embedding provider wrapper
308pub struct CachedEmbeddingProvider<P: EmbeddingProvider> {
309    /// Inner provider
310    inner: P,
311
312    /// LRU cache: text hash -> embedding
313    cache: Cache<u64, Vec<f32>>,
314
315    /// Cache statistics
316    stats: Arc<CacheStats>,
317}
318
319/// Cache statistics
320#[derive(Debug, Default)]
321pub struct CacheStats {
322    /// Number of cache hits
323    pub hits: std::sync::atomic::AtomicUsize,
324    /// Number of cache misses
325    pub misses: std::sync::atomic::AtomicUsize,
326    /// Number of embeddings cached
327    pub size: std::sync::atomic::AtomicUsize,
328}
329
330impl CacheStats {
331    /// Get hit rate
332    pub fn hit_rate(&self) -> f64 {
333        let hits = self.hits.load(std::sync::atomic::Ordering::Relaxed);
334        let misses = self.misses.load(std::sync::atomic::Ordering::Relaxed);
335        let total = hits + misses;
336        if total == 0 {
337            0.0
338        } else {
339            hits as f64 / total as f64
340        }
341    }
342}
343
344impl<P: EmbeddingProvider> CachedEmbeddingProvider<P> {
345    /// Create a new cached provider
346    pub fn new(inner: P, cache_size: usize) -> Self {
347        Self {
348            inner,
349            cache: Cache::new(cache_size as u64),
350            stats: Arc::new(CacheStats::default()),
351        }
352    }
353
354    /// Create with TTL
355    pub fn with_ttl(inner: P, cache_size: usize, ttl_secs: u64) -> Self {
356        let cache = Cache::builder()
357            .max_capacity(cache_size as u64)
358            .time_to_live(std::time::Duration::from_secs(ttl_secs))
359            .build();
360
361        Self {
362            inner,
363            cache,
364            stats: Arc::new(CacheStats::default()),
365        }
366    }
367
368    /// Get cache statistics
369    pub fn stats(&self) -> &Arc<CacheStats> {
370        &self.stats
371    }
372
373    /// Compute hash for cache key
374    fn text_hash(text: &str) -> u64 {
375        use std::collections::hash_map::DefaultHasher;
376        use std::hash::{Hash, Hasher};
377
378        let mut hasher = DefaultHasher::new();
379        text.hash(&mut hasher);
380        hasher.finish()
381    }
382}
383
384impl<P: EmbeddingProvider> EmbeddingProvider for CachedEmbeddingProvider<P> {
385    fn model_name(&self) -> &str {
386        self.inner.model_name()
387    }
388
389    fn dimension(&self) -> usize {
390        self.inner.dimension()
391    }
392
393    fn max_length(&self) -> usize {
394        self.inner.max_length()
395    }
396
397    fn embed(&self, text: &str) -> EmbeddingResult<Vec<f32>> {
398        let hash = Self::text_hash(text);
399
400        // Check cache
401        if let Some(cached) = self.cache.get(&hash) {
402            self.stats
403                .hits
404                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
405            return Ok(cached);
406        }
407
408        self.stats
409            .misses
410            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
411
412        // Generate embedding
413        let embedding = self.inner.embed(text)?;
414
415        // Cache result
416        self.cache.insert(hash, embedding.clone());
417        self.stats
418            .size
419            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
420
421        Ok(embedding)
422    }
423
424    fn embed_batch(&self, texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
425        let mut results = Vec::with_capacity(texts.len());
426        let mut uncached: Vec<(usize, &str)> = Vec::new();
427
428        // Check cache for each text
429        for (i, text) in texts.iter().enumerate() {
430            let hash = Self::text_hash(text);
431            if let Some(cached) = self.cache.get(&hash) {
432                self.stats
433                    .hits
434                    .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
435                results.push((i, cached));
436            } else {
437                self.stats
438                    .misses
439                    .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
440                uncached.push((i, *text));
441            }
442        }
443
444        // Generate embeddings for uncached texts
445        if !uncached.is_empty() {
446            let uncached_texts: Vec<&str> = uncached.iter().map(|(_, t)| *t).collect();
447            let embeddings = self.inner.embed_batch(&uncached_texts)?;
448
449            for ((i, text), embedding) in uncached.iter().zip(embeddings.into_iter()) {
450                let hash = Self::text_hash(text);
451                self.cache.insert(hash, embedding.clone());
452                self.stats
453                    .size
454                    .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
455                results.push((*i, embedding));
456            }
457        }
458
459        // Sort by original index
460        results.sort_by_key(|(i, _)| *i);
461        Ok(results.into_iter().map(|(_, e)| e).collect())
462    }
463}
464
465// ============================================================================
466// Local ONNX Provider (Stub)
467// ============================================================================
468
469/// Local ONNX-based embedding provider
470///
471/// This is a stub implementation. In production, this would use:
472/// - ort (ONNX Runtime) for model inference
473/// - fastembed-rs for pre-packaged models
474/// - tokenizers for text preprocessing
475#[derive(Debug)]
476pub struct LocalOnnxProvider {
477    config: EmbeddingConfig,
478    /// Model weights (placeholder)
479    #[allow(dead_code)]
480    model_loaded: bool,
481}
482
483impl LocalOnnxProvider {
484    /// Create a new local ONNX provider
485    pub fn new(config: EmbeddingConfig) -> EmbeddingResult<Self> {
486        // In production: Load ONNX model from path
487        Ok(Self {
488            config,
489            model_loaded: false,
490        })
491    }
492
493    /// Load a pre-trained model by name
494    pub fn load_pretrained(model_name: &str) -> EmbeddingResult<Self> {
495        let config = EmbeddingConfig::sentence_transformer(model_name);
496        Self::new(config)
497    }
498}
499
500impl EmbeddingProvider for LocalOnnxProvider {
501    fn model_name(&self) -> &str {
502        &self.config.model
503    }
504
505    fn dimension(&self) -> usize {
506        self.config.dimension
507    }
508
509    fn max_length(&self) -> usize {
510        self.config.max_length
511    }
512
513    fn embed(&self, text: &str) -> EmbeddingResult<Vec<f32>> {
514        // Stub: Return mock embedding
515        // In production: Run ONNX inference
516        let mock = MockEmbeddingProvider::with_config(self.config.clone());
517        mock.embed(text)
518    }
519}
520
521// ============================================================================
522// Embedding-Enabled Vector Index
523// ============================================================================
524
525/// Vector index with automatic text embedding
526pub struct EmbeddingVectorIndex<V, P>
527where
528    V: crate::context_query::VectorIndex,
529    P: EmbeddingProvider,
530{
531    /// Underlying vector index
532    index: Arc<V>,
533
534    /// Embedding provider
535    provider: Arc<P>,
536}
537
538impl<V, P> EmbeddingVectorIndex<V, P>
539where
540    V: crate::context_query::VectorIndex,
541    P: EmbeddingProvider,
542{
543    /// Create a new embedding-enabled vector index
544    pub fn new(index: Arc<V>, provider: Arc<P>) -> Self {
545        Self { index, provider }
546    }
547
548    /// Search by text (automatically generates embedding)
549    pub fn search_text(
550        &self,
551        collection: &str,
552        text: &str,
553        k: usize,
554        min_score: Option<f32>,
555    ) -> Result<Vec<crate::context_query::VectorSearchResult>, String> {
556        // Generate embedding
557        let embedding = self.provider.embed(text).map_err(|e| e.to_string())?;
558
559        // Search by embedding
560        self.index
561            .search_by_embedding(collection, &embedding, k, min_score)
562    }
563
564    /// Search by embedding (pass-through)
565    pub fn search_embedding(
566        &self,
567        collection: &str,
568        embedding: &[f32],
569        k: usize,
570        min_score: Option<f32>,
571    ) -> Result<Vec<crate::context_query::VectorSearchResult>, String> {
572        // Validate dimension
573        if embedding.len() != self.provider.dimension() {
574            return Err(format!(
575                "Embedding dimension mismatch: expected {}, got {}",
576                self.provider.dimension(),
577                embedding.len()
578            ));
579        }
580
581        self.index
582            .search_by_embedding(collection, embedding, k, min_score)
583    }
584
585    /// Get the embedding provider
586    pub fn provider(&self) -> &Arc<P> {
587        &self.provider
588    }
589
590    /// Get the underlying index
591    pub fn index(&self) -> &Arc<V> {
592        &self.index
593    }
594}
595
596impl<V, P> crate::context_query::VectorIndex for EmbeddingVectorIndex<V, P>
597where
598    V: crate::context_query::VectorIndex,
599    P: EmbeddingProvider,
600{
601    fn search_by_embedding(
602        &self,
603        collection: &str,
604        embedding: &[f32],
605        k: usize,
606        min_score: Option<f32>,
607    ) -> Result<Vec<crate::context_query::VectorSearchResult>, String> {
608        self.search_embedding(collection, embedding, k, min_score)
609    }
610
611    fn search_by_text(
612        &self,
613        collection: &str,
614        text: &str,
615        k: usize,
616        min_score: Option<f32>,
617    ) -> Result<Vec<crate::context_query::VectorSearchResult>, String> {
618        self.search_text(collection, text, k, min_score)
619    }
620
621    fn stats(&self, collection: &str) -> Option<crate::context_query::VectorIndexStats> {
622        self.index.stats(collection)
623    }
624}
625
626// ============================================================================
627// Convenience Functions
628// ============================================================================
629
630/// Create a cached mock embedding provider for testing
631pub fn create_mock_provider(
632    dimension: usize,
633    cache_size: usize,
634) -> CachedEmbeddingProvider<MockEmbeddingProvider> {
635    let mock = MockEmbeddingProvider::new(dimension);
636    CachedEmbeddingProvider::new(mock, cache_size)
637}
638
639/// Create an embedding-enabled vector index with mock provider
640pub fn create_embedding_index<V: crate::context_query::VectorIndex>(
641    index: Arc<V>,
642    dimension: usize,
643) -> EmbeddingVectorIndex<V, CachedEmbeddingProvider<MockEmbeddingProvider>> {
644    let provider = Arc::new(create_mock_provider(dimension, 10_000));
645    EmbeddingVectorIndex::new(index, provider)
646}
647
648// ============================================================================
649// Tests
650// ============================================================================
651
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Read one of the cache-stat counters.
    fn load(counter: &AtomicUsize) -> usize {
        counter.load(Ordering::Relaxed)
    }

    #[test]
    fn test_mock_embedding_deterministic() {
        let provider = MockEmbeddingProvider::new(384);

        let emb1 = provider.embed("hello world").unwrap();
        let emb2 = provider.embed("hello world").unwrap();

        assert_eq!(emb1, emb2);
        assert_eq!(emb1.len(), 384);
    }

    #[test]
    fn test_mock_embedding_different_texts() {
        let provider = MockEmbeddingProvider::new(384);

        let emb1 = provider.embed("hello").unwrap();
        let emb2 = provider.embed("world").unwrap();

        assert_ne!(emb1, emb2);
    }

    #[test]
    fn test_cached_provider() {
        let mock = MockEmbeddingProvider::new(128);
        let cached = CachedEmbeddingProvider::new(mock, 100);

        // First call - miss
        let first = cached.embed("test text").unwrap();
        assert_eq!(load(&cached.stats().hits), 0);
        assert_eq!(load(&cached.stats().misses), 1);

        // Second call - hit, returning the same vector
        let second = cached.embed("test text").unwrap();
        assert_eq!(first, second);
        assert_eq!(load(&cached.stats().hits), 1);
        assert_eq!(load(&cached.stats().misses), 1);

        assert!((cached.stats().hit_rate() - 0.5).abs() < 1e-12);
    }

    #[test]
    fn test_batch_embedding() {
        let mock = MockEmbeddingProvider::new(128);
        let cached = CachedEmbeddingProvider::new(mock, 100);

        let texts = vec!["hello", "world", "test"];
        let embeddings = cached.embed_batch(&texts).unwrap();

        assert_eq!(embeddings.len(), 3);
        for emb in &embeddings {
            assert_eq!(emb.len(), 128);
        }
    }

    #[test]
    fn test_batch_matches_single_embeddings_in_order() {
        let reference = MockEmbeddingProvider::new(64);
        let cached = CachedEmbeddingProvider::new(MockEmbeddingProvider::new(64), 100);

        // Warm the cache for the middle text so the batch mixes hits and misses.
        cached.embed("world").unwrap();

        let texts = ["hello", "world", "test"];
        let batch = cached.embed_batch(&texts).unwrap();
        let expected: Vec<Vec<f32>> = texts
            .iter()
            .map(|t| reference.embed(t).unwrap())
            .collect();

        assert_eq!(batch, expected);
        assert_eq!(load(&cached.stats().hits), 1);
        assert_eq!(load(&cached.stats().misses), 3);
    }

    #[test]
    fn test_batch_with_duplicate_texts() {
        let cached = CachedEmbeddingProvider::new(MockEmbeddingProvider::new(32), 100);

        let batch = cached.embed_batch(&["same", "other", "same"]).unwrap();

        assert_eq!(batch.len(), 3);
        assert_eq!(batch[0], batch[2]);
        assert_ne!(batch[0], batch[1]);
        assert_eq!(batch[0], cached.embed("same").unwrap());
    }

    #[test]
    fn test_empty_batch() {
        let cached = CachedEmbeddingProvider::new(MockEmbeddingProvider::new(8), 10);
        let none: [&str; 0] = [];

        assert!(cached.embed_batch(&none).unwrap().is_empty());
        assert_eq!(cached.stats().hit_rate(), 0.0);
    }

    #[test]
    fn test_normalization() {
        let provider = MockEmbeddingProvider::new(3);
        let emb = provider.embed("test").unwrap();

        // Check L2 norm is approximately 1
        let norm: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_normalize_leaves_zero_vector_untouched() {
        let provider = MockEmbeddingProvider::new(3);
        let mut zeros = [0.0f32; 3];

        provider.normalize(&mut zeros);

        assert_eq!(zeros, [0.0; 3]);
    }

    #[test]
    fn test_text_too_long() {
        let config = EmbeddingConfig {
            max_length: 10,
            ..Default::default()
        };
        let provider = MockEmbeddingProvider::with_config(config);
        let text = "this is a very long text that exceeds the limit";

        match provider.embed(text) {
            Err(EmbeddingError::TextTooLong { max_length, actual }) => {
                assert_eq!(max_length, 10);
                assert_eq!(actual, text.len());
            }
            other => panic!("expected TextTooLong, got {:?}", other),
        }
    }

    #[test]
    fn test_config_presets() {
        assert_eq!(
            EmbeddingConfig::sentence_transformer("all-mpnet-base-v2").dimension,
            768
        );
        assert_eq!(
            EmbeddingConfig::sentence_transformer("all-MiniLM-L6-v2").dimension,
            384
        );

        let large = EmbeddingConfig::openai("text-embedding-3-large");
        assert_eq!(large.dimension, 3072);
        assert_eq!(large.max_length, 8192);
    }

    #[test]
    fn test_local_onnx_stub_matches_config() {
        let provider = LocalOnnxProvider::load_pretrained("all-mpnet-base-v2").unwrap();

        assert_eq!(provider.model_name(), "all-mpnet-base-v2");
        assert_eq!(provider.embed("hi").unwrap().len(), 768);
    }
}