1use crate::types::{
5 MemoryError, MemoryResult, DEFAULT_EMBEDDING_DIMENSION, DEFAULT_EMBEDDING_MODEL,
6};
7#[cfg(feature = "local-embeddings")]
8use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
9use once_cell::sync::OnceCell;
10use sha2::{Digest, Sha256};
11#[cfg(feature = "local-embeddings")]
12use std::path::PathBuf;
13use std::sync::Arc;
14use tokio::sync::Mutex;
15
// Concrete backend model type held by `EmbeddingService`.
// With the `local-embeddings` feature this is fastembed's `TextEmbedding`.
#[cfg(feature = "local-embeddings")]
type EmbeddingBackend = TextEmbedding;
// Without the feature it is a unit placeholder so the struct keeps the same shape;
// `init_model` never constructs a backend in that build.
#[cfg(not(feature = "local-embeddings"))]
type EmbeddingBackend = ();
20
21pub struct EmbeddingService {
23 model_name: String,
24 dimension: usize,
25 model: Option<EmbeddingBackend>,
26 disabled_reason: Option<String>,
27 deterministic: bool,
28}
29
30impl EmbeddingService {
31 pub fn new() -> Self {
33 Self::with_model(
34 DEFAULT_EMBEDDING_MODEL.to_string(),
35 DEFAULT_EMBEDDING_DIMENSION,
36 )
37 }
38
39 pub fn with_model(model_name: String, dimension: usize) -> Self {
41 let (model, disabled_reason) = Self::init_model(&model_name);
42
43 if let Some(reason) = &disabled_reason {
44 tracing::warn!(
45 target: "tandem.memory",
46 "Embeddings disabled: model={} reason={}",
47 model_name,
48 reason
49 );
50 } else {
51 tracing::info!(
52 target: "tandem.memory",
53 "Embeddings enabled: model={} dimension={}",
54 model_name,
55 dimension
56 );
57 }
58
59 Self {
60 model_name,
61 dimension,
62 model,
63 disabled_reason,
64 deterministic: false,
65 }
66 }
67
68 pub fn deterministic_for_tests(dimension: usize) -> Self {
72 Self {
73 model_name: "deterministic-test-embedding".to_string(),
74 dimension,
75 model: None,
76 disabled_reason: None,
77 deterministic: true,
78 }
79 }
80
81 fn init_model(model_name: &str) -> (Option<EmbeddingBackend>, Option<String>) {
82 #[cfg(not(feature = "local-embeddings"))]
83 {
84 let _ = model_name;
85 (
86 None,
87 Some("local embeddings are disabled at build time".to_string()),
88 )
89 }
90
91 #[cfg(feature = "local-embeddings")]
92 {
93 if let Some(reason) = embeddings_runtime_disabled_reason() {
94 return (None, Some(reason));
95 }
96
97 let Some(parsed_model) = Self::parse_model_id(model_name) else {
98 return (
99 None,
100 Some(format!(
101 "unsupported embedding model id '{}'; supported: {}",
102 model_name, DEFAULT_EMBEDDING_MODEL
103 )),
104 );
105 };
106
107 let cache_dir = resolve_embedding_cache_dir();
108 let options = InitOptions::new(parsed_model).with_cache_dir(cache_dir.clone());
109
110 tracing::info!(
111 target: "tandem.memory",
112 "Initializing embeddings with cache dir: {}",
113 cache_dir.display()
114 );
115
116 match TextEmbedding::try_new(options) {
117 Ok(model) => (Some(model), None),
118 Err(err) => (
119 None,
120 Some(format!(
121 "failed to initialize embedding model '{}': {}",
122 model_name, err
123 )),
124 ),
125 }
126 }
127 }
128
129 #[cfg(feature = "local-embeddings")]
130 fn parse_model_id(model_name: &str) -> Option<EmbeddingModel> {
131 match model_name.trim().to_ascii_lowercase().as_str() {
132 "all-minilm-l6-v2" | "all_minilm_l6_v2" => Some(EmbeddingModel::AllMiniLML6V2),
133 _ => None,
134 }
135 }
136
137 pub fn dimension(&self) -> usize {
139 self.dimension
140 }
141
142 pub fn model_name(&self) -> &str {
144 &self.model_name
145 }
146
147 pub fn is_available(&self) -> bool {
149 self.model.is_some()
150 }
151
152 pub fn disabled_reason(&self) -> Option<&str> {
154 self.disabled_reason.as_deref()
155 }
156
157 fn unavailable_error(&self) -> MemoryError {
158 let reason = self
159 .disabled_reason
160 .as_deref()
161 .unwrap_or("embedding backend unavailable");
162 MemoryError::Embedding(format!("embeddings disabled: {reason}"))
163 }
164
165 fn deterministic_embedding(&self, text: &str) -> Vec<f32> {
166 let mut values = Vec::with_capacity(self.dimension);
167 let mut seed = Sha256::digest(text.as_bytes()).to_vec();
168 while values.len() < self.dimension {
169 for byte in &seed {
170 let value = (*byte as f32 / 127.5) - 1.0;
171 values.push(value);
172 if values.len() == self.dimension {
173 break;
174 }
175 }
176 seed = Sha256::digest(&seed).to_vec();
177 }
178 values
179 }
180
181 #[cfg(feature = "local-embeddings")]
182 fn ensure_dimension(&self, embedding: &[f32]) -> MemoryResult<()> {
183 if embedding.len() != self.dimension {
184 return Err(MemoryError::Embedding(format!(
185 "embedding dimension mismatch: expected {}, got {}",
186 self.dimension,
187 embedding.len()
188 )));
189 }
190 Ok(())
191 }
192
193 pub async fn embed(&self, text: &str) -> MemoryResult<Vec<f32>> {
195 if self.deterministic {
196 return Ok(self.deterministic_embedding(text));
197 }
198
199 #[cfg(not(feature = "local-embeddings"))]
200 {
201 let _ = text;
202 Err(self.unavailable_error())
203 }
204
205 #[cfg(feature = "local-embeddings")]
206 {
207 let Some(model) = self.model.as_ref() else {
208 return Err(self.unavailable_error());
209 };
210
211 let mut embeddings = model
212 .embed(vec![text.to_string()], None)
213 .map_err(|e| MemoryError::Embedding(e.to_string()))?;
214 let embedding = embeddings
215 .pop()
216 .ok_or_else(|| MemoryError::Embedding("no embedding generated".to_string()))?;
217 self.ensure_dimension(&embedding)?;
218 Ok(embedding)
219 }
220 }
221
222 pub async fn embed_batch(&self, texts: &[String]) -> MemoryResult<Vec<Vec<f32>>> {
224 if self.deterministic {
225 return Ok(texts
226 .iter()
227 .map(|text| self.deterministic_embedding(text))
228 .collect());
229 }
230
231 #[cfg(not(feature = "local-embeddings"))]
232 {
233 let _ = texts;
234 Err(self.unavailable_error())
235 }
236
237 #[cfg(feature = "local-embeddings")]
238 {
239 let Some(model) = self.model.as_ref() else {
240 return Err(self.unavailable_error());
241 };
242
243 let embeddings = model
244 .embed(texts.to_vec(), None)
245 .map_err(|e| MemoryError::Embedding(e.to_string()))?;
246
247 for embedding in &embeddings {
248 self.ensure_dimension(embedding)?;
249 }
250
251 Ok(embeddings)
252 }
253 }
254
255 pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
257 if a.len() != b.len() {
258 return 0.0;
259 }
260
261 let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
262 let magnitude_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
263 let magnitude_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
264
265 if magnitude_a == 0.0 || magnitude_b == 0.0 {
266 0.0
267 } else {
268 dot_product / (magnitude_a * magnitude_b)
269 }
270 }
271
272 pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
274 a.iter()
275 .zip(b.iter())
276 .map(|(x, y)| (x - y).powi(2))
277 .sum::<f32>()
278 .sqrt()
279 }
280}
281
282#[cfg(feature = "local-embeddings")]
283fn embeddings_runtime_disabled_reason() -> Option<String> {
284 let disable = std::env::var("TANDEM_DISABLE_EMBEDDINGS")
285 .ok()
286 .map(|v| v.trim().to_ascii_lowercase())
287 .filter(|v| !v.is_empty())
288 .map(|v| matches!(v.as_str(), "1" | "true" | "yes" | "on"))
289 .unwrap_or(false);
290
291 if disable {
292 return Some("disabled by TANDEM_DISABLE_EMBEDDINGS".to_string());
293 }
294 None
295}
296
/// Chooses the directory fastembed caches model files in and tries to create it.
///
/// A non-blank `FASTEMBED_CACHE_DIR` takes precedence. Otherwise the path is
/// `<data_local_dir | cache_dir | temp_dir>/tandem/fastembed`, using the first
/// base directory that is available. Failure to create the directory is only
/// logged (target `tandem.memory`); the chosen path is returned regardless.
#[cfg(feature = "local-embeddings")]
fn resolve_embedding_cache_dir() -> PathBuf {
    // `var_os` also honours non-UTF-8 paths. An empty or blank value is treated
    // as unset; otherwise it would resolve to "" (the current working directory).
    let explicit = std::env::var_os("FASTEMBED_CACHE_DIR")
        .filter(|value| !value.to_string_lossy().trim().is_empty());

    if let Some(explicit) = explicit {
        let explicit_path = PathBuf::from(explicit);
        if let Err(err) = std::fs::create_dir_all(&explicit_path) {
            tracing::warn!(
                target: "tandem.memory",
                "Failed to create FASTEMBED_CACHE_DIR {:?}: {}",
                explicit_path,
                err
            );
        }
        return explicit_path;
    }

    let base = dirs::data_local_dir()
        .or_else(dirs::cache_dir)
        .unwrap_or_else(std::env::temp_dir);
    let cache_dir = base.join("tandem").join("fastembed");

    if let Err(err) = std::fs::create_dir_all(&cache_dir) {
        tracing::warn!(
            target: "tandem.memory",
            "Failed to create embedding cache directory {:?}: {}",
            cache_dir,
            err
        );
    }

    cache_dir
}
328
329impl Default for EmbeddingService {
330 fn default() -> Self {
331 Self::new()
332 }
333}
334
/// Process-wide embedding service. It is populated at most once, either by
/// `init_embedding_service` or lazily by `get_embedding_service`.
static EMBEDDING_SERVICE: OnceCell<Arc<Mutex<EmbeddingService>>> = OnceCell::new();
337
338pub async fn get_embedding_service() -> Arc<Mutex<EmbeddingService>> {
340 EMBEDDING_SERVICE
341 .get_or_init(|| Arc::new(Mutex::new(EmbeddingService::new())))
342 .clone()
343}
344
345pub fn init_embedding_service(model_name: Option<String>, dimension: Option<usize>) {
347 let service = if let (Some(name), Some(dim)) = (model_name, dimension) {
348 EmbeddingService::with_model(name, dim)
349 } else {
350 EmbeddingService::new()
351 };
352
353 let _ = EMBEDDING_SERVICE.set(Arc::new(Mutex::new(service)));
354}
355
#[cfg(test)]
mod tests {
    use super::*;

    /// Real backend: either produces default-dimension vectors or reports
    /// "embeddings disabled" when unavailable in this build/environment.
    #[tokio::test]
    async fn test_embedding_dimension_or_unavailable() {
        let service = EmbeddingService::new();

        if !service.is_available() {
            let err = service.embed("Hello world").await.unwrap_err();
            assert!(err.to_string().contains("embeddings disabled"));
            return;
        }

        let embedding = service.embed("Hello world").await.unwrap();
        assert_eq!(embedding.len(), DEFAULT_EMBEDDING_DIMENSION);
    }

    /// Real backend batch path: one vector per input, or an error when unavailable.
    #[tokio::test]
    async fn test_embed_batch_or_unavailable() {
        let service = EmbeddingService::new();
        let texts = vec![
            "First text".to_string(),
            "Second text".to_string(),
            "Third text".to_string(),
        ];

        let result = service.embed_batch(&texts).await;
        if !service.is_available() {
            assert!(result.is_err());
            return;
        }

        let embeddings = result.unwrap();
        assert_eq!(embeddings.len(), 3);
        for emb in &embeddings {
            assert_eq!(emb.len(), DEFAULT_EMBEDDING_DIMENSION);
        }
    }

    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0f32, 0.0, 0.0];
        let b = vec![1.0f32, 0.0, 0.0];
        let c = vec![0.0f32, 1.0, 0.0];

        let sim_same = EmbeddingService::cosine_similarity(&a, &b);
        let sim_orthogonal = EmbeddingService::cosine_similarity(&a, &c);

        assert!((sim_same - 1.0).abs() < 1e-6);
        assert!(sim_orthogonal.abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_degenerate_inputs() {
        let short = [1.0f32, 0.0];
        let long = [1.0f32, 0.0, 0.0];
        let zero = [0.0f32, 0.0, 0.0];

        assert_eq!(EmbeddingService::cosine_similarity(&short, &long), 0.0);
        assert_eq!(EmbeddingService::cosine_similarity(&zero, &long), 0.0);
    }

    #[test]
    fn test_euclidean_distance() {
        let origin = [0.0f32, 0.0];
        let point = [3.0f32, 4.0];

        assert!((EmbeddingService::euclidean_distance(&origin, &point) - 5.0).abs() < 1e-6);
        assert_eq!(EmbeddingService::euclidean_distance(&origin, &origin), 0.0);
    }

    /// Deterministic mode: stable for equal input, distinct for different input,
    /// correct length (40 > 32 exercises the re-hash path) and bounded values.
    #[tokio::test]
    async fn test_deterministic_embeddings_are_stable() {
        let service = EmbeddingService::deterministic_for_tests(40);

        let first = service.embed("Hello world").await.unwrap();
        let second = service.embed("Hello world").await.unwrap();
        let other = service.embed("Goodbye world").await.unwrap();

        assert_eq!(first.len(), 40);
        assert_eq!(first, second);
        assert_ne!(first, other);
        assert!(first.iter().all(|v| (-1.0..=1.0).contains(v)));
    }

    #[tokio::test]
    async fn test_deterministic_batch_matches_single() {
        let service = EmbeddingService::deterministic_for_tests(16);
        let texts = vec!["alpha".to_string(), "beta".to_string()];

        let batch = service.embed_batch(&texts).await.unwrap();
        assert_eq!(batch.len(), texts.len());
        for (text, emb) in texts.iter().zip(&batch) {
            assert_eq!(&service.embed(text).await.unwrap(), emb);
        }
    }
}