Skip to main content

smelt_memory/embedder/
fastembed_impl.rs

1//! FastEmbed-based embedding implementation
2
3use super::traits::Embedder;
4use super::DEFAULT_DIMENSION;
5use crate::error::{MemoryError, MemoryResult};
6use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
7use std::sync::Arc;
8
9/// FastEmbed-based embedder using BGE-Small model
10pub struct FastEmbedder {
11    model: Arc<TextEmbedding>,
12    dimension: usize,
13}
14
15impl FastEmbedder {
16    /// Create a new FastEmbedder with the default BGE-Small model
17    pub fn new() -> MemoryResult<Self> {
18        Self::with_model(EmbeddingModel::BGESmallENV15)
19    }
20
21    /// Create a FastEmbedder with a specific model
22    pub fn with_model(model: EmbeddingModel) -> MemoryResult<Self> {
23        let embedding =
24            TextEmbedding::try_new(InitOptions::new(model).with_show_download_progress(true))
25                .map_err(|e| {
26                    MemoryError::Embedding(format!("Failed to initialize embedding model: {}", e))
27                })?;
28
29        // Get dimension from first test embedding
30        let dimension = match embedding.embed(vec!["test"], None) {
31            Ok(embeddings) if !embeddings.is_empty() => embeddings[0].len(),
32            _ => DEFAULT_DIMENSION,
33        };
34
35        Ok(Self {
36            model: Arc::new(embedding),
37            dimension,
38        })
39    }
40
41    /// Create a dummy embedder for testing (returns random vectors)
42    #[cfg(test)]
43    pub fn dummy() -> Self {
44        Self {
45            model: Arc::new(
46                TextEmbedding::try_new(InitOptions::new(EmbeddingModel::BGESmallENV15))
47                    .expect("Failed to create test model"),
48            ),
49            dimension: DEFAULT_DIMENSION,
50        }
51    }
52}
53
54impl Embedder for FastEmbedder {
55    fn dimension(&self) -> usize {
56        self.dimension
57    }
58
59    fn embed(&self, text: &str) -> MemoryResult<Vec<f32>> {
60        let embeddings = self
61            .model
62            .embed(vec![text], None)
63            .map_err(|e| MemoryError::Embedding(format!("Embedding failed: {}", e)))?;
64
65        embeddings
66            .into_iter()
67            .next()
68            .ok_or_else(|| MemoryError::Embedding("No embedding generated".to_string()))
69    }
70
71    fn embed_batch(&self, texts: &[&str]) -> MemoryResult<Vec<Vec<f32>>> {
72        if texts.is_empty() {
73            return Ok(Vec::new());
74        }
75
76        let texts_owned: Vec<String> = texts.iter().map(|s| s.to_string()).collect();
77        let texts_refs: Vec<&str> = texts_owned.iter().map(|s| s.as_str()).collect();
78
79        self.model
80            .embed(texts_refs, None)
81            .map_err(|e| MemoryError::Embedding(format!("Batch embedding failed: {}", e)))
82    }
83}
84
85impl Default for FastEmbedder {
86    fn default() -> Self {
87        Self::new().expect("Failed to create default FastEmbedder")
88    }
89}
90
#[cfg(test)]
mod tests {
    use super::*;

    // Note: these tests download the model on first run, so they are ignored in CI.
    // Run them manually with `cargo test -- --ignored`.

    /// Cosine similarity of two equal-length vectors.
    ///
    /// No extra `#[cfg(test)]` is needed: the whole module is already test-only.
    fn cosine_sim(a: &[f32], b: &[f32]) -> f32 {
        assert_eq!(a.len(), b.len(), "vectors must have the same length");
        let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
        dot / (norm_a * norm_b)
    }

    #[test]
    #[ignore = "Requires model download"]
    fn test_embed_single() {
        let embedder = FastEmbedder::new().unwrap();
        let embedding = embedder.embed("Hello, world!").unwrap();

        assert_eq!(embedding.len(), embedder.dimension());
        // The model's output is expected to be (roughly) unit-length.
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.1, "expected ~unit norm, got {}", norm);
    }

    #[test]
    #[ignore = "Requires model download"]
    fn test_embed_batch() {
        let embedder = FastEmbedder::new().unwrap();
        let embeddings = embedder
            .embed_batch(&["First text", "Second text", "Third text"])
            .unwrap();

        assert_eq!(embeddings.len(), 3);
        for emb in &embeddings {
            assert_eq!(emb.len(), embedder.dimension());
        }
    }

    #[test]
    #[ignore = "Requires model download"]
    fn test_embed_batch_empty() {
        let embedder = FastEmbedder::new().unwrap();
        assert!(embedder.embed_batch(&[]).unwrap().is_empty());
    }

    #[test]
    #[ignore = "Requires model download"]
    fn test_similar_texts() {
        let embedder = FastEmbedder::new().unwrap();

        let e1 = embedder.embed("Fix authentication bug in login").unwrap();
        let e2 = embedder.embed("Repair auth issue in sign-in").unwrap();
        let e3 = embedder.embed("Add new database migration").unwrap();

        // Semantically similar texts should have higher cosine similarity.
        let sim_12 = cosine_sim(&e1, &e2);
        let sim_13 = cosine_sim(&e1, &e3);

        assert!(
            sim_12 > sim_13,
            "Similar texts should have higher similarity ({} vs {})",
            sim_12,
            sim_13
        );
    }
}