1use std::num::NonZeroUsize;
19use std::sync::Arc;
20
21use anyhow::{Context, Result};
22use async_trait::async_trait;
23use fastembed::{EmbeddingModel, TextEmbedding, TextInitOptions};
24use lru::LruCache;
25use parking_lot::Mutex;
26
/// Length of every vector produced by [`FastEmbedder`]; its `dimension()` reports this value.
pub const EMBED_DIM: usize = 384;

/// Number of entries held by the LRU cache of an embedder built with [`FastEmbedder::new`].
pub const DEFAULT_CACHE_CAPACITY: usize = 256;
38
#[async_trait]
/// Turns text into fixed-length `f32` vectors.
///
/// The `Send + Sync` supertraits let a single embedder (e.g. behind `&dyn Embedder` or an
/// `Arc`) be shared across threads and async tasks.
pub trait Embedder: Send + Sync {
    /// Embeds every string in `texts`.
    ///
    /// Implementations are expected to return exactly one vector per input, in the same order
    /// as `texts` ([`embed_one`] relies on this). Both implementations in this module return an
    /// empty `Vec` for empty input.
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>>;

    /// Length of the vectors this embedder produces.
    fn dimension(&self) -> usize;
}
55
56pub async fn embed_one(embedder: &dyn Embedder, text: &str) -> Result<Vec<f32>> {
65 let mut v = embedder.embed_batch(&[text.to_string()]).await?;
66 v.pop()
67 .context("embedder returned no embedding for non-empty input")
68}
69
/// [`Embedder`] backed by a local fastembed `TextEmbedding` model, with an LRU cache keyed by
/// the exact input string.
///
/// Construct with [`FastEmbedder::new`] or [`FastEmbedder::with_cache_size`].
pub struct FastEmbedder {
    // The model is locked because `embed` is invoked through a mutable guard; the `Arc` lets a
    // clone of the handle be moved into `spawn_blocking` closures.
    model: Arc<Mutex<TextEmbedding>>,
    // Input text -> previously computed embedding; capacity is fixed at construction.
    cache: Arc<Mutex<LruCache<String, Vec<f32>>>>,
    // Value reported by `dimension()`; the constructor sets it to EMBED_DIM.
    dim: usize,
}
87
88impl FastEmbedder {
89 pub async fn new() -> Result<Self> {
91 Self::with_cache_size(DEFAULT_CACHE_CAPACITY).await
92 }
93
94 pub async fn with_cache_size(capacity: usize) -> Result<Self> {
96 let capacity =
97 NonZeroUsize::new(capacity.max(1)).expect("capacity.max(1) is always non-zero");
98
99 let model = tokio::task::spawn_blocking(|| -> Result<TextEmbedding> {
102 let mut m =
103 TextEmbedding::try_new(TextInitOptions::new(EmbeddingModel::AllMiniLML6V2Q))
104 .or_else(|q_err| {
105 tracing::warn!(
106 "AllMiniLML6V2Q init failed ({q_err:#}), falling back to AllMiniLML6V2"
107 );
108 TextEmbedding::try_new(TextInitOptions::new(EmbeddingModel::AllMiniLML6V2))
109 })
110 .context(
111 "failed to initialise fastembed (tried AllMiniLML6V2Q and AllMiniLML6V2)",
112 )?;
113
114 let warmup: Vec<&str> = vec![
116 "hello world",
117 "the quick brown fox",
118 "memory palace warmup",
119 "embedding model ready",
120 "trusty common warmup",
121 ];
122 let _ = m
123 .embed(warmup, None)
124 .context("fastembed warmup batch failed")?;
125 Ok(m)
126 })
127 .await
128 .context("spawn_blocking joined with error during embedder init")??;
129
130 Ok(Self {
131 model: Arc::new(Mutex::new(model)),
132 cache: Arc::new(Mutex::new(LruCache::new(capacity))),
133 dim: EMBED_DIM,
134 })
135 }
136}
137
138#[async_trait]
139impl Embedder for FastEmbedder {
140 async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
141 if texts.is_empty() {
142 return Ok(Vec::new());
143 }
144
145 let mut results: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
147 let mut to_compute: Vec<(usize, String)> = Vec::new();
148 {
149 let mut cache = self.cache.lock();
150 for (i, t) in texts.iter().enumerate() {
151 if let Some(v) = cache.get(t) {
152 results[i] = Some(v.clone());
153 } else {
154 to_compute.push((i, t.clone()));
155 }
156 }
157 }
158
159 if !to_compute.is_empty() {
160 let model = Arc::clone(&self.model);
161 let owned: Vec<String> = to_compute.iter().map(|(_, s)| s.clone()).collect();
162 let computed = tokio::task::spawn_blocking(move || -> Result<Vec<Vec<f32>>> {
163 let mut guard = model.lock();
164 guard
165 .embed(owned, None)
166 .context("fastembed embed call failed")
167 })
168 .await
169 .context("spawn_blocking joined with error during embed")??;
170
171 if computed.len() != to_compute.len() {
172 anyhow::bail!(
173 "fastembed returned {} embeddings, expected {}",
174 computed.len(),
175 to_compute.len()
176 );
177 }
178
179 let mut cache = self.cache.lock();
180 for ((idx, key), vector) in to_compute.into_iter().zip(computed.into_iter()) {
181 cache.put(key, vector.clone());
182 results[idx] = Some(vector);
183 }
184 }
185
186 results
187 .into_iter()
188 .map(|opt| opt.context("missing embedding slot after batch"))
189 .collect()
190 }
191
192 fn dimension(&self) -> usize {
193 self.dim
194 }
195}
196
/// Deterministic, model-free stand-in for a real embedder.
///
/// Only compiled for tests or with the `test-support` feature. Vectors are derived purely from
/// the bytes of the input, so identical text always yields identical output.
#[cfg(any(test, feature = "test-support"))]
#[derive(Debug, Clone)]
pub struct MockEmbedder {
    dim: usize,
}

#[cfg(any(test, feature = "test-support"))]
impl MockEmbedder {
    /// Creates a mock that produces vectors of length `dim`.
    ///
    /// A `dim` of 0 is accepted and yields empty vectors.
    pub fn new(dim: usize) -> Self {
        Self { dim }
    }

    /// Folds the bytes of `text` into a `dim`-length vector.
    ///
    /// - A byte `b` at position `i` adds `b / 255` to slot `(i + b) % dim`.
    /// - The first byte is then added once more to slot 0.
    fn hash_to_vec(&self, text: &str) -> Vec<f32> {
        let mut v = vec![0.0_f32; self.dim];
        // With no slots there is nothing to fold into, and `% 0` / `v[0]` would panic.
        if self.dim == 0 {
            return v;
        }
        for (i, b) in text.bytes().enumerate() {
            v[(i + usize::from(b)) % self.dim] += f32::from(b) / 255.0;
        }
        if let Some(first) = text.bytes().next() {
            v[0] += f32::from(first) / 255.0;
        }
        v
    }
}
228
#[cfg(any(test, feature = "test-support"))]
#[async_trait]
impl Embedder for MockEmbedder {
    /// Hashes each input into a pseudo-embedding, preserving input order; always returns `Ok`.
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        let mut vectors = Vec::with_capacity(texts.len());
        for text in texts {
            vectors.push(self.hash_to_vec(text));
        }
        Ok(vectors)
    }

    /// Reports the length given to [`MockEmbedder::new`].
    fn dimension(&self) -> usize {
        self.dim
    }
}
240
#[cfg(test)]
mod tests {
    use super::*;

    // The mock must honour the trait contract: vectors of the advertised length, one per
    // input, and different texts producing different vectors.
    #[tokio::test]
    async fn mock_embedder_round_trip() {
        let e = MockEmbedder::new(EMBED_DIM);
        assert_eq!(e.dimension(), EMBED_DIM);
        let v = embed_one(&e, "hello").await.unwrap();
        assert_eq!(v.len(), EMBED_DIM);
        let batch = e
            .embed_batch(&["a".to_string(), "b".to_string()])
            .await
            .unwrap();
        assert_eq!(batch.len(), 2);
        assert_ne!(batch[0], batch[1]);
    }

    // Empty input is not an error; it yields an empty result.
    #[tokio::test]
    async fn mock_embedder_empty_input_returns_empty() {
        let e = MockEmbedder::new(EMBED_DIM);
        let v = e.embed_batch(&[]).await.unwrap();
        assert!(v.is_empty());
    }

    // The tests below construct a real FastEmbedder and are #[ignore]d. NOTE(review): presumably
    // because model initialisation may need to fetch weights; run with `cargo test -- --ignored`.
    #[tokio::test]
    #[ignore]
    async fn fastembed_returns_correct_dim() {
        let e = FastEmbedder::new().await.unwrap();
        assert_eq!(e.dimension(), 384);
        let v = embed_one(&e, "fn authenticate(user: &str) -> bool")
            .await
            .unwrap();
        assert_eq!(v.len(), 384);
        // Guards against a degenerate all-zero output.
        assert!(v.iter().any(|x| *x != 0.0));
    }

    // The second call is served from the LRU cache; it must return the same vector.
    #[tokio::test]
    #[ignore]
    async fn fastembed_cache_hit_is_idempotent() {
        let e = FastEmbedder::new().await.unwrap();
        let v1 = embed_one(&e, "cached").await.unwrap();
        let v2 = embed_one(&e, "cached").await.unwrap();
        assert_eq!(v1, v2);
    }
}
288}