//! Local text-embedding service for tandem memory.
//!
//! Wraps an optional `fastembed` backend (behind the `local-embeddings`
//! feature) and exposes vector helpers (cosine similarity, euclidean
//! distance). When the backend cannot be initialized the service stays
//! usable but reports itself as disabled.

use crate::types::{
    MemoryError, MemoryResult, DEFAULT_EMBEDDING_DIMENSION, DEFAULT_EMBEDDING_MODEL,
};
#[cfg(feature = "local-embeddings")]
use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
use once_cell::sync::OnceCell;
#[cfg(feature = "local-embeddings")]
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::Mutex;

// With the feature enabled the backend is a real fastembed model; otherwise a
// zero-sized placeholder so the struct definition is feature-independent.
#[cfg(feature = "local-embeddings")]
type EmbeddingBackend = TextEmbedding;
#[cfg(not(feature = "local-embeddings"))]
type EmbeddingBackend = ();
19
/// Text-embedding service: an optional local backend plus its configuration.
pub struct EmbeddingService {
    // Identifier of the configured embedding model (kept even if init failed).
    model_name: String,
    // Expected length of every embedding vector this service produces.
    dimension: usize,
    // `Some` when a backend initialized successfully; `None` when embeddings
    // are unavailable (feature off, env kill switch, or init failure).
    model: Option<EmbeddingBackend>,
    // Human-readable reason why `model` is `None`, if it is.
    disabled_reason: Option<String>,
}
27
28impl EmbeddingService {
29 pub fn new() -> Self {
31 Self::with_model(
32 DEFAULT_EMBEDDING_MODEL.to_string(),
33 DEFAULT_EMBEDDING_DIMENSION,
34 )
35 }
36
37 pub fn with_model(model_name: String, dimension: usize) -> Self {
39 let (model, disabled_reason) = Self::init_model(&model_name);
40
41 if let Some(reason) = &disabled_reason {
42 tracing::warn!(
43 target: "tandem.memory",
44 "Embeddings disabled: model={} reason={}",
45 model_name,
46 reason
47 );
48 } else {
49 tracing::info!(
50 target: "tandem.memory",
51 "Embeddings enabled: model={} dimension={}",
52 model_name,
53 dimension
54 );
55 }
56
57 Self {
58 model_name,
59 dimension,
60 model,
61 disabled_reason,
62 }
63 }
64
65 fn init_model(model_name: &str) -> (Option<EmbeddingBackend>, Option<String>) {
66 #[cfg(not(feature = "local-embeddings"))]
67 {
68 let _ = model_name;
69 return (
70 None,
71 Some("local embeddings are disabled at build time".to_string()),
72 );
73 }
74
75 #[cfg(feature = "local-embeddings")]
76 {
77 if let Some(reason) = embeddings_runtime_disabled_reason() {
78 return (None, Some(reason));
79 }
80
81 let Some(parsed_model) = Self::parse_model_id(model_name) else {
82 return (
83 None,
84 Some(format!(
85 "unsupported embedding model id '{}'; supported: {}",
86 model_name, DEFAULT_EMBEDDING_MODEL
87 )),
88 );
89 };
90
91 let cache_dir = resolve_embedding_cache_dir();
92 let options = InitOptions::new(parsed_model).with_cache_dir(cache_dir.clone());
93
94 tracing::info!(
95 target: "tandem.memory",
96 "Initializing embeddings with cache dir: {}",
97 cache_dir.display()
98 );
99
100 match TextEmbedding::try_new(options) {
101 Ok(model) => (Some(model), None),
102 Err(err) => (
103 None,
104 Some(format!(
105 "failed to initialize embedding model '{}': {}",
106 model_name, err
107 )),
108 ),
109 }
110 }
111 }
112
113 #[cfg(feature = "local-embeddings")]
114 fn parse_model_id(model_name: &str) -> Option<EmbeddingModel> {
115 match model_name.trim().to_ascii_lowercase().as_str() {
116 "all-minilm-l6-v2" | "all_minilm_l6_v2" => Some(EmbeddingModel::AllMiniLML6V2),
117 _ => None,
118 }
119 }
120
121 pub fn dimension(&self) -> usize {
123 self.dimension
124 }
125
126 pub fn model_name(&self) -> &str {
128 &self.model_name
129 }
130
131 pub fn is_available(&self) -> bool {
133 self.model.is_some()
134 }
135
136 pub fn disabled_reason(&self) -> Option<&str> {
138 self.disabled_reason.as_deref()
139 }
140
141 fn unavailable_error(&self) -> MemoryError {
142 let reason = self
143 .disabled_reason
144 .as_deref()
145 .unwrap_or("embedding backend unavailable");
146 MemoryError::Embedding(format!("embeddings disabled: {reason}"))
147 }
148
149 #[cfg(feature = "local-embeddings")]
150 fn ensure_dimension(&self, embedding: &[f32]) -> MemoryResult<()> {
151 if embedding.len() != self.dimension {
152 return Err(MemoryError::Embedding(format!(
153 "embedding dimension mismatch: expected {}, got {}",
154 self.dimension,
155 embedding.len()
156 )));
157 }
158 Ok(())
159 }
160
161 pub async fn embed(&self, text: &str) -> MemoryResult<Vec<f32>> {
163 #[cfg(not(feature = "local-embeddings"))]
164 {
165 let _ = text;
166 return Err(self.unavailable_error());
167 }
168
169 #[cfg(feature = "local-embeddings")]
170 {
171 let Some(model) = self.model.as_ref() else {
172 return Err(self.unavailable_error());
173 };
174
175 let mut embeddings = model
176 .embed(vec![text.to_string()], None)
177 .map_err(|e| MemoryError::Embedding(e.to_string()))?;
178 let embedding = embeddings
179 .pop()
180 .ok_or_else(|| MemoryError::Embedding("no embedding generated".to_string()))?;
181 self.ensure_dimension(&embedding)?;
182 Ok(embedding)
183 }
184 }
185
186 pub async fn embed_batch(&self, texts: &[String]) -> MemoryResult<Vec<Vec<f32>>> {
188 #[cfg(not(feature = "local-embeddings"))]
189 {
190 let _ = texts;
191 return Err(self.unavailable_error());
192 }
193
194 #[cfg(feature = "local-embeddings")]
195 {
196 let Some(model) = self.model.as_ref() else {
197 return Err(self.unavailable_error());
198 };
199
200 let embeddings = model
201 .embed(texts.to_vec(), None)
202 .map_err(|e| MemoryError::Embedding(e.to_string()))?;
203
204 for embedding in &embeddings {
205 self.ensure_dimension(embedding)?;
206 }
207
208 Ok(embeddings)
209 }
210 }
211
212 pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
214 if a.len() != b.len() {
215 return 0.0;
216 }
217
218 let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
219 let magnitude_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
220 let magnitude_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
221
222 if magnitude_a == 0.0 || magnitude_b == 0.0 {
223 0.0
224 } else {
225 dot_product / (magnitude_a * magnitude_b)
226 }
227 }
228
229 pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
231 a.iter()
232 .zip(b.iter())
233 .map(|(x, y)| (x - y).powi(2))
234 .sum::<f32>()
235 .sqrt()
236 }
237}
238
/// Checks the TANDEM_DISABLE_EMBEDDINGS env var and returns a reason string
/// when it is set to a truthy value ("1", "true", "yes", "on",
/// case-insensitively); otherwise `None`.
#[cfg(feature = "local-embeddings")]
fn embeddings_runtime_disabled_reason() -> Option<String> {
    match std::env::var("TANDEM_DISABLE_EMBEDDINGS") {
        Ok(raw) => {
            let value = raw.trim().to_ascii_lowercase();
            // Anything outside the truthy set (including empty) leaves
            // embeddings enabled.
            if matches!(value.as_str(), "1" | "true" | "yes" | "on") {
                Some("disabled by TANDEM_DISABLE_EMBEDDINGS".to_string())
            } else {
                None
            }
        }
        Err(_) => None,
    }
}
253
/// Resolves the directory where fastembed stores downloaded model files.
///
/// FASTEMBED_CACHE_DIR, when set, wins; otherwise a `tandem/fastembed`
/// subdirectory of the platform's local-data (or cache, or temp) directory
/// is used. Creation failures are logged but not fatal — the path is
/// returned regardless so the backend can report its own error.
#[cfg(feature = "local-embeddings")]
fn resolve_embedding_cache_dir() -> PathBuf {
    match std::env::var("FASTEMBED_CACHE_DIR") {
        Ok(explicit) => {
            let explicit_path = PathBuf::from(explicit);
            if let Err(err) = std::fs::create_dir_all(&explicit_path) {
                tracing::warn!(
                    target: "tandem.memory",
                    "Failed to create FASTEMBED_CACHE_DIR {:?}: {}",
                    explicit_path,
                    err
                );
            }
            explicit_path
        }
        Err(_) => {
            let base = dirs::data_local_dir()
                .or_else(dirs::cache_dir)
                .unwrap_or_else(std::env::temp_dir);
            let cache_dir = base.join("tandem").join("fastembed");

            if let Err(err) = std::fs::create_dir_all(&cache_dir) {
                tracing::warn!(
                    target: "tandem.memory",
                    "Failed to create embedding cache directory {:?}: {}",
                    cache_dir,
                    err
                );
            }

            cache_dir
        }
    }
}
285
impl Default for EmbeddingService {
    /// Equivalent to [`EmbeddingService::new`]: default model and dimension.
    fn default() -> Self {
        Self::new()
    }
}
291
// Process-wide singleton; whichever initializer runs first (explicit init or
// lazy construction in `get_embedding_service`) wins.
static EMBEDDING_SERVICE: OnceCell<Arc<Mutex<EmbeddingService>>> = OnceCell::new();

/// Returns the shared embedding service, lazily constructing it with the
/// defaults on first use.
///
/// NOTE(review): declared async (presumably for API symmetry with callers)
/// but performs no awaits — confirm the signature is intentional.
pub async fn get_embedding_service() -> Arc<Mutex<EmbeddingService>> {
    EMBEDDING_SERVICE
        .get_or_init(|| Arc::new(Mutex::new(EmbeddingService::new())))
        .clone()
}
301
302pub fn init_embedding_service(model_name: Option<String>, dimension: Option<usize>) {
304 let service = if let (Some(name), Some(dim)) = (model_name, dimension) {
305 EmbeddingService::with_model(name, dim)
306 } else {
307 EmbeddingService::new()
308 };
309
310 let _ = EMBEDDING_SERVICE.set(Arc::new(Mutex::new(service)));
311}
312
#[cfg(test)]
mod tests {
    use super::*;

    // embed() must either fail with the "disabled" error (no backend in this
    // build/environment) or produce a vector of the default dimension.
    #[tokio::test]
    async fn test_embedding_dimension_or_unavailable() {
        let service = EmbeddingService::new();

        if !service.is_available() {
            let err = service.embed("Hello world").await.unwrap_err();
            assert!(err.to_string().contains("embeddings disabled"));
            return;
        }

        let embedding = service.embed("Hello world").await.unwrap();
        assert_eq!(embedding.len(), DEFAULT_EMBEDDING_DIMENSION);
    }

    // embed_batch() must err when the backend is unavailable; otherwise it
    // returns one correctly-sized vector per input, in order.
    #[tokio::test]
    async fn test_embed_batch_or_unavailable() {
        let service = EmbeddingService::new();
        let texts = vec![
            "First text".to_string(),
            "Second text".to_string(),
            "Third text".to_string(),
        ];

        let result = service.embed_batch(&texts).await;
        if !service.is_available() {
            assert!(result.is_err());
            return;
        }

        let embeddings = result.unwrap();
        assert_eq!(embeddings.len(), 3);
        for emb in &embeddings {
            assert_eq!(emb.len(), DEFAULT_EMBEDDING_DIMENSION);
        }
    }

    // Identical unit vectors have similarity 1; orthogonal unit vectors 0.
    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0f32, 0.0, 0.0];
        let b = vec![1.0f32, 0.0, 0.0];
        let c = vec![0.0f32, 1.0, 0.0];

        let sim_same = EmbeddingService::cosine_similarity(&a, &b);
        let sim_orthogonal = EmbeddingService::cosine_similarity(&a, &c);

        assert!((sim_same - 1.0).abs() < 1e-6);
        assert!(sim_orthogonal.abs() < 1e-6);
    }
}