// trueno_rag/embed/mod.rs
1//! Embedding generation for RAG pipelines
2
3#[cfg(feature = "nemotron")]
4mod nemotron;
5#[cfg(feature = "nemotron")]
6pub use nemotron::{NemotronConfig, NemotronEmbedder};
7
8use crate::{Chunk, Error, Result};
9use async_trait::async_trait;
10use serde::{Deserialize, Serialize};
11
12/// Pooling strategy for token embeddings
13#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
14pub enum PoolingStrategy {
15    /// Use [CLS] token embedding
16    Cls,
17    /// Mean of all token embeddings
18    Mean,
19    /// Mean with attention weighting
20    WeightedMean,
21    /// Last token (for decoder models)
22    LastToken,
23}
24
25impl Default for PoolingStrategy {
26    fn default() -> Self {
27        Self::Mean
28    }
29}
30
/// Configuration for embedding generation
///
/// The [`Default`] impl yields: normalized vectors, no prefixes,
/// a 512-token limit, and mean pooling.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
    /// Normalize embeddings to unit (L2) length
    pub normalize: bool,
    /// Instruction prefix prepended to queries (asymmetric retrieval)
    pub query_prefix: Option<String>,
    /// Instruction prefix prepended to documents
    pub document_prefix: Option<String>,
    /// Maximum sequence length in tokens
    pub max_length: usize,
    /// Pooling strategy applied to token embeddings
    pub pooling: PoolingStrategy,
}
45
46impl Default for EmbeddingConfig {
47    fn default() -> Self {
48        Self {
49            normalize: true,
50            query_prefix: None,
51            document_prefix: None,
52            max_length: 512,
53            pooling: PoolingStrategy::Mean,
54        }
55    }
56}
57
/// Trait for embedding generation
///
/// NOTE(review): every method is synchronous, so `#[async_trait]` is
/// currently a no-op — presumably kept for future async implementors;
/// confirm before removing.
#[async_trait]
pub trait Embedder: Send + Sync {
    /// Embed a single text into a dense vector of `dimension()` length
    fn embed(&self, text: &str) -> Result<Vec<f32>>;

    /// Batch embed multiple texts; implementations are expected to return
    /// one embedding per input, in input order (see `embed_chunks`)
    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>>;

    /// Get embedding dimension
    fn dimension(&self) -> usize;

    /// Get model identifier
    fn model_id(&self) -> &str;

    /// Embed a query (implementations may apply a query prefix).
    /// Default implementation delegates to [`Embedder::embed`].
    fn embed_query(&self, query: &str) -> Result<Vec<f32>> {
        self.embed(query)
    }

    /// Embed a document (implementations may apply a document prefix).
    /// Default implementation delegates to [`Embedder::embed`].
    fn embed_document(&self, document: &str) -> Result<Vec<f32>> {
        self.embed(document)
    }

    /// Embed chunks and update them in place.
    ///
    /// Pairs chunks with embeddings positionally via `zip`, so it relies
    /// on `embed_batch` returning exactly one embedding per chunk in
    /// order; any surplus chunks are left without an embedding.
    fn embed_chunks(&self, chunks: &mut [Chunk]) -> Result<()> {
        let texts: Vec<&str> = chunks.iter().map(|c| c.content.as_str()).collect();
        let embeddings = self.embed_batch(&texts)?;

        for (chunk, embedding) in chunks.iter_mut().zip(embeddings) {
            chunk.set_embedding(embedding);
        }

        Ok(())
    }
}
95
96/// Blanket impl so `HybridRetriever<Box<dyn Embedder>>` works without
97/// requiring the caller to know the concrete embedder type at compile time.
98impl Embedder for Box<dyn Embedder> {
99    fn embed(&self, text: &str) -> Result<Vec<f32>> {
100        (**self).embed(text)
101    }
102
103    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
104        (**self).embed_batch(texts)
105    }
106
107    fn dimension(&self) -> usize {
108        (**self).dimension()
109    }
110
111    fn model_id(&self) -> &str {
112        (**self).model_id()
113    }
114
115    fn embed_query(&self, query: &str) -> Result<Vec<f32>> {
116        (**self).embed_query(query)
117    }
118
119    fn embed_document(&self, document: &str) -> Result<Vec<f32>> {
120        (**self).embed_document(document)
121    }
122
123    fn embed_chunks(&self, chunks: &mut [Chunk]) -> Result<()> {
124        (**self).embed_chunks(chunks)
125    }
126}
127
/// Mock embedder for testing (uses simple hash-based vectors)
///
/// Deterministic within a given Rust version: the same text always maps
/// to the same vector, so tests can compare embeddings across calls.
/// The vectors carry no semantic meaning.
#[derive(Debug, Clone)]
pub struct MockEmbedder {
    // Length of the vectors produced
    dimension: usize,
    // Reported by `model_id()`; defaults to "mock-embedder"
    model_id: String,
    // Controls normalization and the query/document prefixes
    config: EmbeddingConfig,
}
135
136impl MockEmbedder {
137    /// Create a new mock embedder
138    #[must_use]
139    pub fn new(dimension: usize) -> Self {
140        Self {
141            dimension,
142            model_id: "mock-embedder".to_string(),
143            config: EmbeddingConfig::default(),
144        }
145    }
146
147    /// Set the model ID
148    #[must_use]
149    pub fn with_model_id(mut self, model_id: impl Into<String>) -> Self {
150        self.model_id = model_id.into();
151        self
152    }
153
154    /// Set configuration
155    #[must_use]
156    pub fn with_config(mut self, config: EmbeddingConfig) -> Self {
157        self.config = config;
158        self
159    }
160
161    fn hash_to_vector(&self, text: &str) -> Vec<f32> {
162        use std::collections::hash_map::DefaultHasher;
163        use std::hash::{Hash, Hasher};
164
165        let mut vector = Vec::with_capacity(self.dimension);
166        let mut hasher = DefaultHasher::new();
167
168        for i in 0..self.dimension {
169            text.hash(&mut hasher);
170            i.hash(&mut hasher);
171            let hash = hasher.finish();
172            // Convert hash to float in range [-1, 1]
173            let value = (hash as f32 / u64::MAX as f32) * 2.0 - 1.0;
174            vector.push(value);
175        }
176
177        if self.config.normalize {
178            Self::normalize_vector(&mut vector);
179        }
180
181        vector
182    }
183
184    fn normalize_vector(vector: &mut [f32]) {
185        let norm: f32 = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
186        if norm > 0.0 {
187            for x in vector.iter_mut() {
188                *x /= norm;
189            }
190        }
191    }
192}
193
194impl Embedder for MockEmbedder {
195    fn embed(&self, text: &str) -> Result<Vec<f32>> {
196        if text.is_empty() {
197            return Err(Error::EmptyDocument("empty text for embedding".to_string()));
198        }
199
200        let prefixed = if let Some(prefix) = &self.config.document_prefix {
201            format!("{prefix}{text}")
202        } else {
203            text.to_string()
204        };
205
206        Ok(self.hash_to_vector(&prefixed))
207    }
208
209    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
210        texts.iter().map(|t| self.embed(t)).collect()
211    }
212
213    fn dimension(&self) -> usize {
214        self.dimension
215    }
216
217    fn model_id(&self) -> &str {
218        &self.model_id
219    }
220
221    fn embed_query(&self, query: &str) -> Result<Vec<f32>> {
222        if query.is_empty() {
223            return Err(Error::Query("empty query".to_string()));
224        }
225
226        let prefixed = if let Some(prefix) = &self.config.query_prefix {
227            format!("{prefix}{query}")
228        } else {
229            query.to_string()
230        };
231
232        Ok(self.hash_to_vector(&prefixed))
233    }
234}
235
/// TF-IDF based embedder (sparse-to-dense conversion)
///
/// Must be trained with [`TfIdfEmbedder::fit`] before embedding; an
/// untrained embedder returns `Error::InvalidConfig` from `embed`.
#[derive(Debug, Clone)]
pub struct TfIdfEmbedder {
    // Output vector length; also caps the vocabulary size
    dimension: usize,
    // term -> index into the output vector and into `idf`
    vocabulary: std::collections::HashMap<String, usize>,
    // Inverse document frequency, one entry per vocabulary index
    idf: Vec<f32>,
}
243
244impl TfIdfEmbedder {
245    /// Create a new TF-IDF embedder (untrained)
246    #[must_use]
247    pub fn new(dimension: usize) -> Self {
248        Self { dimension, vocabulary: std::collections::HashMap::new(), idf: Vec::new() }
249    }
250
251    /// Train the embedder on a corpus
252    pub fn fit(&mut self, documents: &[&str]) {
253        use std::collections::{HashMap, HashSet};
254
255        let mut doc_freq: HashMap<String, usize> = HashMap::new();
256        let mut all_terms: HashSet<String> = HashSet::new();
257
258        for doc in documents {
259            let terms: HashSet<String> = doc.split_whitespace().map(|s| s.to_lowercase()).collect();
260
261            for term in &terms {
262                *doc_freq.entry(term.clone()).or_insert(0) += 1;
263                all_terms.insert(term.clone());
264            }
265        }
266
267        // Build vocabulary (top N terms by document frequency)
268        let mut terms: Vec<_> = all_terms.into_iter().collect();
269        terms.sort_by_key(|t| std::cmp::Reverse(doc_freq.get(t).copied().unwrap_or(0)));
270        terms.truncate(self.dimension);
271
272        self.vocabulary = terms.iter().enumerate().map(|(i, t)| (t.clone(), i)).collect();
273
274        // Compute IDF
275        let n = documents.len() as f32;
276        self.idf = terms
277            .iter()
278            .map(|t| {
279                let df = doc_freq.get(t).copied().unwrap_or(1) as f32;
280                (n / df).max(f32::EPSILON).ln() + 1.0
281            })
282            .collect();
283    }
284
285    fn compute_tf(&self, text: &str) -> Vec<f32> {
286        let mut tf = vec![0.0f32; self.dimension];
287        let terms: Vec<String> = text.split_whitespace().map(|s| s.to_lowercase()).collect();
288        let total = terms.len() as f32;
289
290        for term in terms {
291            if let Some(&idx) = self.vocabulary.get(&term) {
292                tf[idx] += 1.0 / total;
293            }
294        }
295
296        tf
297    }
298}
299
300impl Embedder for TfIdfEmbedder {
301    fn embed(&self, text: &str) -> Result<Vec<f32>> {
302        if text.is_empty() {
303            return Err(Error::EmptyDocument("empty text".to_string()));
304        }
305
306        if self.vocabulary.is_empty() {
307            return Err(Error::InvalidConfig("embedder not trained".to_string()));
308        }
309
310        let tf = self.compute_tf(text);
311        let mut tfidf: Vec<f32> = tf.iter().zip(self.idf.iter()).map(|(t, i)| t * i).collect();
312
313        // Normalize
314        let norm: f32 = tfidf.iter().map(|x| x * x).sum::<f32>().sqrt();
315        if norm > 0.0 {
316            for x in &mut tfidf {
317                *x /= norm;
318            }
319        }
320
321        // Pad to dimension if needed
322        tfidf.resize(self.dimension, 0.0);
323        Ok(tfidf)
324    }
325
326    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
327        texts.iter().map(|t| self.embed(t)).collect()
328    }
329
330    fn dimension(&self) -> usize {
331        self.dimension
332    }
333
334    fn model_id(&self) -> &str {
335        "tfidf"
336    }
337}
338
/// Compute the L2 (Euclidean) norm of a vector.
///
/// (A stray leftover doc line claiming this computes cosine similarity
/// has been removed; that description belongs on `cosine_similarity`.)
#[must_use]
fn l2_norm(v: &[f32]) -> f32 {
    v.iter().map(|x| x * x).sum::<f32>().sqrt()
}
345
/// Divide `numerator` by `denominator`, mapping a zero denominator to 0.0.
fn safe_divide(numerator: f32, denominator: f32) -> f32 {
    match denominator {
        d if d == 0.0 => 0.0,
        d => numerator / d,
    }
}
354
355/// Compute cosine similarity between two vectors.
356pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
357    if a.len() != b.len() {
358        return 0.0;
359    }
360    safe_divide(dot_product(a, b), l2_norm(a) * l2_norm(b))
361}
362
/// Compute the dot product of two vectors (pairs are truncated to the
/// shorter slice).
#[must_use]
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).fold(0.0, |acc, (x, y)| acc + x * y)
}
368
/// Compute the Euclidean (L2) distance between two vectors (pairs are
/// truncated to the shorter slice).
#[must_use]
pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    let sum_sq = a.iter().zip(b.iter()).fold(0.0f32, |acc, (x, y)| {
        let d = x - y;
        acc + d * d
    });
    sum_sq.sqrt()
}
374
375// ============================================================================
376// FastEmbed-based Embedder (GH-1: Production-ready semantic embeddings)
377// ============================================================================
378
/// Available embedding models when `embeddings` feature is enabled
///
/// Manual `Default` impl replaced with `#[derive(Default)]` +
/// `#[default]` (same behavior: `AllMiniLmL6V2`).
#[cfg(feature = "embeddings")]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum EmbeddingModelType {
    /// all-MiniLM-L6-v2: Fast, good quality (384 dims)
    #[default]
    AllMiniLmL6V2,
    /// all-MiniLM-L12-v2: Better quality, slightly slower (384 dims)
    AllMiniLmL12V2,
    /// BGE-small-en-v1.5: Balanced performance (384 dims)
    BgeSmallEnV15,
    /// BGE-base-en-v1.5: Higher quality (768 dims)
    BgeBaseEnV15,
    /// NomicEmbed-text-v1: Good for retrieval (768 dims)
    NomicEmbedTextV1,
}
401
#[cfg(feature = "embeddings")]
impl EmbeddingModelType {
    /// Map this variant to the corresponding fastembed model enum
    fn to_fastembed_model(self) -> fastembed::EmbeddingModel {
        match self {
            Self::AllMiniLmL6V2 => fastembed::EmbeddingModel::AllMiniLML6V2,
            Self::AllMiniLmL12V2 => fastembed::EmbeddingModel::AllMiniLML12V2,
            Self::BgeSmallEnV15 => fastembed::EmbeddingModel::BGESmallENV15,
            Self::BgeBaseEnV15 => fastembed::EmbeddingModel::BGEBaseENV15,
            Self::NomicEmbedTextV1 => fastembed::EmbeddingModel::NomicEmbedTextV1,
        }
    }

    /// Get the embedding dimension for this model
    ///
    /// Hard-coded table; `FastEmbedder::dimension` reads this instead of
    /// querying the model, so it must stay in sync with the real outputs.
    #[must_use]
    pub const fn dimension(self) -> usize {
        match self {
            Self::AllMiniLmL6V2 | Self::AllMiniLmL12V2 | Self::BgeSmallEnV15 => 384,
            Self::BgeBaseEnV15 | Self::NomicEmbedTextV1 => 768,
        }
    }

    /// Get human-readable model name (org/model identifier string)
    #[must_use]
    pub const fn model_name(self) -> &'static str {
        match self {
            Self::AllMiniLmL6V2 => "sentence-transformers/all-MiniLM-L6-v2",
            Self::AllMiniLmL12V2 => "sentence-transformers/all-MiniLM-L12-v2",
            Self::BgeSmallEnV15 => "BAAI/bge-small-en-v1.5",
            Self::BgeBaseEnV15 => "BAAI/bge-base-en-v1.5",
            Self::NomicEmbedTextV1 => "nomic-ai/nomic-embed-text-v1",
        }
    }
}
436
/// Production-ready semantic embedder using fastembed (ONNX Runtime)
///
/// Requires the `embeddings` feature to be enabled.
///
/// Cloning is cheap: clones share the underlying model via `Arc`. The
/// model sits behind a `Mutex` because embedding requires mutable access
/// to it, so concurrent embeds serialize on the lock.
///
/// # Example
///
/// ```rust,ignore
/// use trueno_rag::embed::{FastEmbedder, EmbeddingModelType, Embedder};
///
/// let embedder = FastEmbedder::new(EmbeddingModelType::AllMiniLmL6V2)?;
/// let embedding = embedder.embed("Hello, world!")?;
/// assert_eq!(embedding.len(), 384);
/// ```
#[cfg(feature = "embeddings")]
#[derive(Clone)]
pub struct FastEmbedder {
    // Shared, lock-serialized access to the ONNX model
    model: std::sync::Arc<std::sync::Mutex<fastembed::TextEmbedding>>,
    // Drives `dimension()` and `model_id()` without touching the model
    model_type: EmbeddingModelType,
}
456
#[cfg(feature = "embeddings")]
// Manual Debug impl: the wrapped model is not Debug, so derive(Debug)
// is unavailable; report the model type and its dimension instead.
impl std::fmt::Debug for FastEmbedder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("FastEmbedder")
            .field("model_type", &self.model_type)
            .field("dimension", &self.model_type.dimension())
            .finish_non_exhaustive() // model field intentionally omitted (not Debug)
    }
}
466
#[cfg(feature = "embeddings")]
impl FastEmbedder {
    /// Create a new FastEmbedder with the specified model
    ///
    /// Downloads the model on first use if not cached (network access;
    /// download-progress display is enabled).
    ///
    /// # Errors
    /// Returns an error if model initialization fails.
    pub fn new(model_type: EmbeddingModelType) -> Result<Self> {
        let options = fastembed::InitOptions::new(model_type.to_fastembed_model())
            .with_show_download_progress(true);

        let model = fastembed::TextEmbedding::try_new(options).map_err(|e| {
            Error::InvalidConfig(format!("Failed to initialize embedding model: {e}"))
        })?;

        // Arc<Mutex<..>> so Clone shares a single loaded model.
        Ok(Self { model: std::sync::Arc::new(std::sync::Mutex::new(model)), model_type })
    }

    /// Create with default model (all-MiniLM-L6-v2)
    ///
    /// # Errors
    /// Returns an error if model initialization fails.
    pub fn default_model() -> Result<Self> {
        Self::new(EmbeddingModelType::default())
    }

    /// Get the model type
    #[must_use]
    pub fn model_type(&self) -> EmbeddingModelType {
        self.model_type
    }
}
500
#[cfg(feature = "embeddings")]
impl Embedder for FastEmbedder {
    /// Embed a single text.
    ///
    /// # Errors
    /// `Error::EmptyDocument` for empty input; `Error::Embedding` if the
    /// model lock is poisoned or inference fails.
    fn embed(&self, text: &str) -> Result<Vec<f32>> {
        if text.is_empty() {
            return Err(Error::EmptyDocument("empty text for embedding".to_string()));
        }

        let mut model =
            self.model.lock().map_err(|e| Error::Embedding(format!("lock failed: {e}")))?;

        let embeddings = model
            .embed(vec![text], None)
            .map_err(|e| Error::Embedding(format!("embedding failed: {e}")))?;

        embeddings
            .into_iter()
            .next()
            .ok_or_else(|| Error::Embedding("no embedding returned".to_string()))
    }

    /// Batch embed, returning exactly one embedding per input, in order.
    ///
    /// BUG FIX: the previous version silently *filtered out* empty texts,
    /// so the result could be shorter than `texts` and positionally
    /// misaligned — `Embedder::embed_chunks` zips results with chunks, so
    /// one empty chunk shifted every later embedding onto the wrong
    /// chunk. Empty inputs are now rejected up front, consistent with
    /// `embed`.
    ///
    /// # Errors
    /// `Error::EmptyDocument` if any input text is empty;
    /// `Error::Embedding` on lock or inference failure.
    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        // Reject rather than filter: filtering breaks the 1:1 pairing
        // between inputs and outputs that callers rely on.
        if let Some(pos) = texts.iter().position(|t| t.is_empty()) {
            return Err(Error::EmptyDocument(format!("empty text at index {pos} in batch")));
        }

        let mut model =
            self.model.lock().map_err(|e| Error::Embedding(format!("lock failed: {e}")))?;

        model
            .embed(texts.to_vec(), None)
            .map_err(|e| Error::Embedding(format!("batch embedding failed: {e}")))
    }

    fn dimension(&self) -> usize {
        self.model_type.dimension()
    }

    fn model_id(&self) -> &str {
        self.model_type.model_name()
    }

    fn embed_query(&self, query: &str) -> Result<Vec<f32>> {
        // Some models use query prefixes, but fastembed handles this internally
        self.embed(query)
    }

    fn embed_document(&self, document: &str) -> Result<Vec<f32>> {
        self.embed(document)
    }
}
557
558#[cfg(test)]
559mod tests;