tandem_memory/
embeddings.rs1use crate::types::{
5 MemoryError, MemoryResult, DEFAULT_EMBEDDING_DIMENSION, DEFAULT_EMBEDDING_MODEL,
6};
7use once_cell::sync::OnceCell;
8#[cfg(feature = "local-embeddings")]
9use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
10#[cfg(feature = "local-embeddings")]
11use std::path::PathBuf;
12use std::sync::Arc;
13use tokio::sync::Mutex;
14
// Backend actually stored in `EmbeddingService::model`. Without the `local-embeddings`
// feature it is the zero-sized unit type, so the service still compiles (and always
// reports itself as unavailable) when fastembed is not linked in.
#[cfg(not(feature = "local-embeddings"))]
type EmbeddingBackend = ();
// With the feature enabled, fastembed's text embedding model is the backend.
#[cfg(feature = "local-embeddings")]
type EmbeddingBackend = TextEmbedding;
19
20pub struct EmbeddingService {
22 model_name: String,
23 dimension: usize,
24 model: Option<EmbeddingBackend>,
25 disabled_reason: Option<String>,
26}
27
28impl EmbeddingService {
29 pub fn new() -> Self {
31 Self::with_model(
32 DEFAULT_EMBEDDING_MODEL.to_string(),
33 DEFAULT_EMBEDDING_DIMENSION,
34 )
35 }
36
37 pub fn with_model(model_name: String, dimension: usize) -> Self {
39 let (model, disabled_reason) = Self::init_model(&model_name);
40
41 if let Some(reason) = &disabled_reason {
42 tracing::warn!(
43 target: "tandem.memory",
44 "Embeddings disabled: model={} reason={}",
45 model_name,
46 reason
47 );
48 } else {
49 tracing::info!(
50 target: "tandem.memory",
51 "Embeddings enabled: model={} dimension={}",
52 model_name,
53 dimension
54 );
55 }
56
57 Self {
58 model_name,
59 dimension,
60 model,
61 disabled_reason,
62 }
63 }
64
65 fn init_model(model_name: &str) -> (Option<EmbeddingBackend>, Option<String>) {
66 #[cfg(not(feature = "local-embeddings"))]
67 {
68 let _ = model_name;
69 return (
70 None,
71 Some("local embeddings are disabled at build time".to_string()),
72 );
73 }
74
75 #[cfg(feature = "local-embeddings")]
76 {
77 let Some(parsed_model) = Self::parse_model_id(model_name) else {
78 return (
79 None,
80 Some(format!(
81 "unsupported embedding model id '{}'; supported: {}",
82 model_name, DEFAULT_EMBEDDING_MODEL
83 )),
84 );
85 };
86
87 let cache_dir = resolve_embedding_cache_dir();
88 let options = InitOptions::new(parsed_model).with_cache_dir(cache_dir.clone());
89
90 tracing::info!(
91 target: "tandem.memory",
92 "Initializing embeddings with cache dir: {}",
93 cache_dir.display()
94 );
95
96 match TextEmbedding::try_new(options) {
97 Ok(model) => (Some(model), None),
98 Err(err) => (
99 None,
100 Some(format!(
101 "failed to initialize embedding model '{}': {}",
102 model_name, err
103 )),
104 ),
105 }
106 }
107 }
108
109 #[cfg(feature = "local-embeddings")]
110 fn parse_model_id(model_name: &str) -> Option<EmbeddingModel> {
111 match model_name.trim().to_ascii_lowercase().as_str() {
112 "all-minilm-l6-v2" | "all_minilm_l6_v2" => Some(EmbeddingModel::AllMiniLML6V2),
113 _ => None,
114 }
115 }
116
117 pub fn dimension(&self) -> usize {
119 self.dimension
120 }
121
122 pub fn model_name(&self) -> &str {
124 &self.model_name
125 }
126
127 pub fn is_available(&self) -> bool {
129 self.model.is_some()
130 }
131
132 pub fn disabled_reason(&self) -> Option<&str> {
134 self.disabled_reason.as_deref()
135 }
136
137 fn unavailable_error(&self) -> MemoryError {
138 let reason = self
139 .disabled_reason
140 .as_deref()
141 .unwrap_or("embedding backend unavailable");
142 MemoryError::Embedding(format!("embeddings disabled: {reason}"))
143 }
144
145 #[cfg(feature = "local-embeddings")]
146 fn ensure_dimension(&self, embedding: &[f32]) -> MemoryResult<()> {
147 if embedding.len() != self.dimension {
148 return Err(MemoryError::Embedding(format!(
149 "embedding dimension mismatch: expected {}, got {}",
150 self.dimension,
151 embedding.len()
152 )));
153 }
154 Ok(())
155 }
156
157 pub async fn embed(&self, text: &str) -> MemoryResult<Vec<f32>> {
159 #[cfg(not(feature = "local-embeddings"))]
160 {
161 let _ = text;
162 return Err(self.unavailable_error());
163 }
164
165 #[cfg(feature = "local-embeddings")]
166 {
167 let Some(model) = self.model.as_ref() else {
168 return Err(self.unavailable_error());
169 };
170
171 let mut embeddings = model
172 .embed(vec![text.to_string()], None)
173 .map_err(|e| MemoryError::Embedding(e.to_string()))?;
174 let embedding = embeddings
175 .pop()
176 .ok_or_else(|| MemoryError::Embedding("no embedding generated".to_string()))?;
177 self.ensure_dimension(&embedding)?;
178 Ok(embedding)
179 }
180 }
181
182 pub async fn embed_batch(&self, texts: &[String]) -> MemoryResult<Vec<Vec<f32>>> {
184 #[cfg(not(feature = "local-embeddings"))]
185 {
186 let _ = texts;
187 return Err(self.unavailable_error());
188 }
189
190 #[cfg(feature = "local-embeddings")]
191 {
192 let Some(model) = self.model.as_ref() else {
193 return Err(self.unavailable_error());
194 };
195
196 let embeddings = model
197 .embed(texts.to_vec(), None)
198 .map_err(|e| MemoryError::Embedding(e.to_string()))?;
199
200 for embedding in &embeddings {
201 self.ensure_dimension(embedding)?;
202 }
203
204 Ok(embeddings)
205 }
206 }
207
208 pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
210 if a.len() != b.len() {
211 return 0.0;
212 }
213
214 let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
215 let magnitude_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
216 let magnitude_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
217
218 if magnitude_a == 0.0 || magnitude_b == 0.0 {
219 0.0
220 } else {
221 dot_product / (magnitude_a * magnitude_b)
222 }
223 }
224
225 pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
227 a.iter()
228 .zip(b.iter())
229 .map(|(x, y)| (x - y).powi(2))
230 .sum::<f32>()
231 .sqrt()
232 }
233}
234
/// Picks, and best-effort creates, the directory that fastembed model files are cached in.
///
/// Resolution order:
/// 1. `FASTEMBED_CACHE_DIR`, when set to a non-empty value. Non-UTF-8 paths are accepted.
/// 2. `<base>/tandem/fastembed`. `<base>` is the platform local-data directory, falling back to
///    the platform cache directory and then the system temp directory.
///
/// Directory-creation failures are logged as warnings on the `tandem.memory` target and
/// otherwise ignored; the chosen path is returned either way.
#[cfg(feature = "local-embeddings")]
fn resolve_embedding_cache_dir() -> PathBuf {
    // `var_os` keeps a non-UTF-8 override that `var` would silently discard. An empty value is
    // treated as unset instead of resolving to the process's current working directory.
    let explicit = std::env::var_os("FASTEMBED_CACHE_DIR").filter(|value| !value.is_empty());
    if let Some(explicit) = explicit {
        let explicit_path = PathBuf::from(explicit);
        if let Err(err) = std::fs::create_dir_all(&explicit_path) {
            tracing::warn!(
                target: "tandem.memory",
                "Failed to create FASTEMBED_CACHE_DIR {:?}: {}",
                explicit_path,
                err
            );
        }
        return explicit_path;
    }

    let base = dirs::data_local_dir()
        .or_else(dirs::cache_dir)
        .unwrap_or_else(std::env::temp_dir);
    let cache_dir = base.join("tandem").join("fastembed");

    if let Err(err) = std::fs::create_dir_all(&cache_dir) {
        tracing::warn!(
            target: "tandem.memory",
            "Failed to create embedding cache directory {:?}: {}",
            cache_dir,
            err
        );
    }

    cache_dir
}
266
267impl Default for EmbeddingService {
268 fn default() -> Self {
269 Self::new()
270 }
271}
272
273static EMBEDDING_SERVICE: OnceCell<Arc<Mutex<EmbeddingService>>> = OnceCell::new();
275
276pub async fn get_embedding_service() -> Arc<Mutex<EmbeddingService>> {
278 EMBEDDING_SERVICE
279 .get_or_init(|| Arc::new(Mutex::new(EmbeddingService::new())))
280 .clone()
281}
282
283pub fn init_embedding_service(model_name: Option<String>, dimension: Option<usize>) {
285 let service = if let (Some(name), Some(dim)) = (model_name, dimension) {
286 EmbeddingService::with_model(name, dim)
287 } else {
288 EmbeddingService::new()
289 };
290
291 let _ = EMBEDDING_SERVICE.set(Arc::new(Mutex::new(service)));
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_embedding_dimension_or_unavailable() {
        let service = EmbeddingService::new();

        if !service.is_available() {
            // Without a backend, embedding must fail with the "disabled" error.
            let err = service.embed("Hello world").await.unwrap_err();
            assert!(err.to_string().contains("embeddings disabled"));
            return;
        }

        let embedding = service.embed("Hello world").await.unwrap();
        assert_eq!(embedding.len(), DEFAULT_EMBEDDING_DIMENSION);
    }

    #[tokio::test]
    async fn test_embed_batch_or_unavailable() {
        let service = EmbeddingService::new();
        let texts = vec![
            "First text".to_string(),
            "Second text".to_string(),
            "Third text".to_string(),
        ];

        let result = service.embed_batch(&texts).await;
        if !service.is_available() {
            assert!(result.is_err());
            return;
        }

        let embeddings = result.unwrap();
        assert_eq!(embeddings.len(), 3);
        for emb in &embeddings {
            assert_eq!(emb.len(), DEFAULT_EMBEDDING_DIMENSION);
        }
    }

    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0f32, 0.0, 0.0];
        let b = vec![1.0f32, 0.0, 0.0];
        let c = vec![0.0f32, 1.0, 0.0];

        let sim_same = EmbeddingService::cosine_similarity(&a, &b);
        let sim_orthogonal = EmbeddingService::cosine_similarity(&a, &c);

        assert!((sim_same - 1.0).abs() < 1e-6);
        assert!(sim_orthogonal.abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_degenerate_inputs() {
        // Length mismatch, zero magnitude and empty input all map to 0.0 instead of NaN.
        assert_eq!(
            EmbeddingService::cosine_similarity(&[1.0, 2.0], &[1.0]),
            0.0
        );
        assert_eq!(
            EmbeddingService::cosine_similarity(&[0.0, 0.0], &[1.0, 1.0]),
            0.0
        );
        assert_eq!(EmbeddingService::cosine_similarity(&[], &[]), 0.0);
    }

    #[test]
    fn test_euclidean_distance() {
        let origin = [0.0f32, 0.0];
        let point = [3.0f32, 4.0];

        assert!((EmbeddingService::euclidean_distance(&origin, &point) - 5.0).abs() < 1e-6);
        assert_eq!(EmbeddingService::euclidean_distance(&point, &point), 0.0);
    }
}