smelt_memory/embedder/
fastembed_impl.rs1use super::traits::Embedder;
4use super::DEFAULT_DIMENSION;
5use crate::error::{MemoryError, MemoryResult};
6use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
7use std::sync::Arc;
8
9pub struct FastEmbedder {
11 model: Arc<TextEmbedding>,
12 dimension: usize,
13}
14
15impl FastEmbedder {
16 pub fn new() -> MemoryResult<Self> {
18 Self::with_model(EmbeddingModel::BGESmallENV15)
19 }
20
21 pub fn with_model(model: EmbeddingModel) -> MemoryResult<Self> {
23 let embedding =
24 TextEmbedding::try_new(InitOptions::new(model).with_show_download_progress(true))
25 .map_err(|e| {
26 MemoryError::Embedding(format!("Failed to initialize embedding model: {}", e))
27 })?;
28
29 let dimension = match embedding.embed(vec!["test"], None) {
31 Ok(embeddings) if !embeddings.is_empty() => embeddings[0].len(),
32 _ => DEFAULT_DIMENSION,
33 };
34
35 Ok(Self {
36 model: Arc::new(embedding),
37 dimension,
38 })
39 }
40
41 #[cfg(test)]
43 pub fn dummy() -> Self {
44 Self {
45 model: Arc::new(
46 TextEmbedding::try_new(InitOptions::new(EmbeddingModel::BGESmallENV15))
47 .expect("Failed to create test model"),
48 ),
49 dimension: DEFAULT_DIMENSION,
50 }
51 }
52}
53
54impl Embedder for FastEmbedder {
55 fn dimension(&self) -> usize {
56 self.dimension
57 }
58
59 fn embed(&self, text: &str) -> MemoryResult<Vec<f32>> {
60 let embeddings = self
61 .model
62 .embed(vec![text], None)
63 .map_err(|e| MemoryError::Embedding(format!("Embedding failed: {}", e)))?;
64
65 embeddings
66 .into_iter()
67 .next()
68 .ok_or_else(|| MemoryError::Embedding("No embedding generated".to_string()))
69 }
70
71 fn embed_batch(&self, texts: &[&str]) -> MemoryResult<Vec<Vec<f32>>> {
72 if texts.is_empty() {
73 return Ok(Vec::new());
74 }
75
76 let texts_owned: Vec<String> = texts.iter().map(|s| s.to_string()).collect();
77 let texts_refs: Vec<&str> = texts_owned.iter().map(|s| s.as_str()).collect();
78
79 self.model
80 .embed(texts_refs, None)
81 .map_err(|e| MemoryError::Embedding(format!("Batch embedding failed: {}", e)))
82 }
83}
84
85impl Default for FastEmbedder {
86 fn default() -> Self {
87 Self::new().expect("Failed to create default FastEmbedder")
88 }
89}
90
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[ignore = "Requires model download"]
    fn test_embed_single() {
        let embedder = FastEmbedder::new().unwrap();
        let embedding = embedder.embed("Hello, world!").unwrap();

        assert_eq!(embedding.len(), embedder.dimension());
        // The default model is expected to emit (approximately) L2-normalized vectors.
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!(
            (norm - 1.0).abs() < 0.1,
            "expected a roughly unit-length embedding, got norm {}",
            norm
        );
    }

    #[test]
    #[ignore = "Requires model download"]
    fn test_embed_batch() {
        let embedder = FastEmbedder::new().unwrap();
        let embeddings = embedder
            .embed_batch(&["First text", "Second text", "Third text"])
            .unwrap();

        assert_eq!(embeddings.len(), 3);
        for emb in &embeddings {
            assert_eq!(emb.len(), embedder.dimension());
        }
    }

    #[test]
    #[ignore = "Requires model download"]
    fn test_similar_texts() {
        let embedder = FastEmbedder::new().unwrap();

        let e1 = embedder.embed("Fix authentication bug in login").unwrap();
        let e2 = embedder.embed("Repair auth issue in sign-in").unwrap();
        let e3 = embedder.embed("Add new database migration").unwrap();

        let sim_12 = cosine_sim(&e1, &e2);
        let sim_13 = cosine_sim(&e1, &e3);

        assert!(
            sim_12 > sim_13,
            "Similar texts should have higher similarity (sim_12 = {}, sim_13 = {})",
            sim_12,
            sim_13
        );
    }

    /// Cosine similarity of `a` and `b`.
    ///
    /// Returns 0.0 when either vector has zero norm. Otherwise the division would give
    /// NaN, and every comparison against NaN is false, which would make assertion
    /// failures confusing.
    fn cosine_sim(a: &[f32], b: &[f32]) -> f32 {
        let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
        let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm_a == 0.0 || norm_b == 0.0 {
            return 0.0;
        }
        dot / (norm_a * norm_b)
    }
}