// tandem_memory/embeddings.rs
// Embedding Service Module
// Generates embeddings using local fastembed implementation.

4use crate::types::{
5    MemoryError, MemoryResult, DEFAULT_EMBEDDING_DIMENSION, DEFAULT_EMBEDDING_MODEL,
6};
7#[cfg(feature = "local-embeddings")]
8use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
9use once_cell::sync::OnceCell;
10#[cfg(feature = "local-embeddings")]
11use std::path::PathBuf;
12use std::sync::Arc;
13use tokio::sync::Mutex;
14
// Backend type behind the service: with the `local-embeddings` feature this
// is fastembed's in-process `TextEmbedding`; without it the alias collapses
// to unit so the struct layout still compiles and every embed call reports
// a "disabled" error instead.
#[cfg(feature = "local-embeddings")]
type EmbeddingBackend = TextEmbedding;
#[cfg(not(feature = "local-embeddings"))]
type EmbeddingBackend = ();
19
/// Embedding service for generating vector representations.
pub struct EmbeddingService {
    // User-facing model id (e.g. the default from DEFAULT_EMBEDDING_MODEL).
    model_name: String,
    // Expected length of every produced vector; outputs are validated against it.
    dimension: usize,
    // `None` when embeddings are unavailable; embed calls then fail.
    model: Option<EmbeddingBackend>,
    // Human-readable explanation of why `model` is `None`, if it is.
    disabled_reason: Option<String>,
}
27
28impl EmbeddingService {
29    /// Create a new embedding service with default model.
30    pub fn new() -> Self {
31        Self::with_model(
32            DEFAULT_EMBEDDING_MODEL.to_string(),
33            DEFAULT_EMBEDDING_DIMENSION,
34        )
35    }
36
37    /// Create with custom model.
38    pub fn with_model(model_name: String, dimension: usize) -> Self {
39        let (model, disabled_reason) = Self::init_model(&model_name);
40
41        if let Some(reason) = &disabled_reason {
42            tracing::warn!(
43                target: "tandem.memory",
44                "Embeddings disabled: model={} reason={}",
45                model_name,
46                reason
47            );
48        } else {
49            tracing::info!(
50                target: "tandem.memory",
51                "Embeddings enabled: model={} dimension={}",
52                model_name,
53                dimension
54            );
55        }
56
57        Self {
58            model_name,
59            dimension,
60            model,
61            disabled_reason,
62        }
63    }
64
65    fn init_model(model_name: &str) -> (Option<EmbeddingBackend>, Option<String>) {
66        #[cfg(not(feature = "local-embeddings"))]
67        {
68            let _ = model_name;
69            return (
70                None,
71                Some("local embeddings are disabled at build time".to_string()),
72            );
73        }
74
75        #[cfg(feature = "local-embeddings")]
76        {
77            if let Some(reason) = embeddings_runtime_disabled_reason() {
78                return (None, Some(reason));
79            }
80
81            let Some(parsed_model) = Self::parse_model_id(model_name) else {
82                return (
83                    None,
84                    Some(format!(
85                        "unsupported embedding model id '{}'; supported: {}",
86                        model_name, DEFAULT_EMBEDDING_MODEL
87                    )),
88                );
89            };
90
91            let cache_dir = resolve_embedding_cache_dir();
92            let options = InitOptions::new(parsed_model).with_cache_dir(cache_dir.clone());
93
94            tracing::info!(
95                target: "tandem.memory",
96                "Initializing embeddings with cache dir: {}",
97                cache_dir.display()
98            );
99
100            match TextEmbedding::try_new(options) {
101                Ok(model) => (Some(model), None),
102                Err(err) => (
103                    None,
104                    Some(format!(
105                        "failed to initialize embedding model '{}': {}",
106                        model_name, err
107                    )),
108                ),
109            }
110        }
111    }
112
113    #[cfg(feature = "local-embeddings")]
114    fn parse_model_id(model_name: &str) -> Option<EmbeddingModel> {
115        match model_name.trim().to_ascii_lowercase().as_str() {
116            "all-minilm-l6-v2" | "all_minilm_l6_v2" => Some(EmbeddingModel::AllMiniLML6V2),
117            _ => None,
118        }
119    }
120
121    /// Get the embedding dimension.
122    pub fn dimension(&self) -> usize {
123        self.dimension
124    }
125
126    /// Get the model name.
127    pub fn model_name(&self) -> &str {
128        &self.model_name
129    }
130
131    /// Returns whether semantic embeddings are currently available.
132    pub fn is_available(&self) -> bool {
133        self.model.is_some()
134    }
135
136    /// Returns disabled reason if embeddings are unavailable.
137    pub fn disabled_reason(&self) -> Option<&str> {
138        self.disabled_reason.as_deref()
139    }
140
141    fn unavailable_error(&self) -> MemoryError {
142        let reason = self
143            .disabled_reason
144            .as_deref()
145            .unwrap_or("embedding backend unavailable");
146        MemoryError::Embedding(format!("embeddings disabled: {reason}"))
147    }
148
149    #[cfg(feature = "local-embeddings")]
150    fn ensure_dimension(&self, embedding: &[f32]) -> MemoryResult<()> {
151        if embedding.len() != self.dimension {
152            return Err(MemoryError::Embedding(format!(
153                "embedding dimension mismatch: expected {}, got {}",
154                self.dimension,
155                embedding.len()
156            )));
157        }
158        Ok(())
159    }
160
161    /// Generate embeddings for a single text.
162    pub async fn embed(&self, text: &str) -> MemoryResult<Vec<f32>> {
163        #[cfg(not(feature = "local-embeddings"))]
164        {
165            let _ = text;
166            return Err(self.unavailable_error());
167        }
168
169        #[cfg(feature = "local-embeddings")]
170        {
171            let Some(model) = self.model.as_ref() else {
172                return Err(self.unavailable_error());
173            };
174
175            let mut embeddings = model
176                .embed(vec![text.to_string()], None)
177                .map_err(|e| MemoryError::Embedding(e.to_string()))?;
178            let embedding = embeddings
179                .pop()
180                .ok_or_else(|| MemoryError::Embedding("no embedding generated".to_string()))?;
181            self.ensure_dimension(&embedding)?;
182            Ok(embedding)
183        }
184    }
185
186    /// Generate embeddings for multiple texts.
187    pub async fn embed_batch(&self, texts: &[String]) -> MemoryResult<Vec<Vec<f32>>> {
188        #[cfg(not(feature = "local-embeddings"))]
189        {
190            let _ = texts;
191            return Err(self.unavailable_error());
192        }
193
194        #[cfg(feature = "local-embeddings")]
195        {
196            let Some(model) = self.model.as_ref() else {
197                return Err(self.unavailable_error());
198            };
199
200            let embeddings = model
201                .embed(texts.to_vec(), None)
202                .map_err(|e| MemoryError::Embedding(e.to_string()))?;
203
204            for embedding in &embeddings {
205                self.ensure_dimension(embedding)?;
206            }
207
208            Ok(embeddings)
209        }
210    }
211
212    /// Calculate cosine similarity between two vectors.
213    pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
214        if a.len() != b.len() {
215            return 0.0;
216        }
217
218        let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
219        let magnitude_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
220        let magnitude_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
221
222        if magnitude_a == 0.0 || magnitude_b == 0.0 {
223            0.0
224        } else {
225            dot_product / (magnitude_a * magnitude_b)
226        }
227    }
228
229    /// Calculate Euclidean distance between two vectors.
230    pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
231        a.iter()
232            .zip(b.iter())
233            .map(|(x, y)| (x - y).powi(2))
234            .sum::<f32>()
235            .sqrt()
236    }
237}
238
/// Check the `TANDEM_DISABLE_EMBEDDINGS` runtime kill switch.
///
/// Returns a reason string when the variable holds a truthy value
/// ("1", "true", "yes", "on" — trimmed, case-insensitive); `None` otherwise.
#[cfg(feature = "local-embeddings")]
fn embeddings_runtime_disabled_reason() -> Option<String> {
    let disabled = std::env::var("TANDEM_DISABLE_EMBEDDINGS")
        .map(|raw| {
            let normalized = raw.trim().to_ascii_lowercase();
            matches!(normalized.as_str(), "1" | "true" | "yes" | "on")
        })
        .unwrap_or(false);

    disabled.then(|| "disabled by TANDEM_DISABLE_EMBEDDINGS".to_string())
}
253
/// Resolve the on-disk directory used to cache embedding model files.
///
/// An explicit `FASTEMBED_CACHE_DIR` always wins; otherwise fall back to the
/// platform local-data dir, then the cache dir, then the temp dir, under
/// `tandem/fastembed`. Directory-creation failures are logged, not fatal.
#[cfg(feature = "local-embeddings")]
fn resolve_embedding_cache_dir() -> PathBuf {
    // Explicit override from the environment takes precedence.
    if let Ok(overridden) = std::env::var("FASTEMBED_CACHE_DIR") {
        let explicit_path = PathBuf::from(overridden);
        match std::fs::create_dir_all(&explicit_path) {
            Ok(()) => {}
            Err(err) => tracing::warn!(
                target: "tandem.memory",
                "Failed to create FASTEMBED_CACHE_DIR {:?}: {}",
                explicit_path,
                err
            ),
        }
        return explicit_path;
    }

    // Per-user default location.
    let cache_dir = dirs::data_local_dir()
        .or_else(dirs::cache_dir)
        .unwrap_or_else(std::env::temp_dir)
        .join("tandem")
        .join("fastembed");

    if let Err(err) = std::fs::create_dir_all(&cache_dir) {
        tracing::warn!(
            target: "tandem.memory",
            "Failed to create embedding cache directory {:?}: {}",
            cache_dir,
            err
        );
    }

    cache_dir
}
285
286impl Default for EmbeddingService {
287    fn default() -> Self {
288        Self::new()
289    }
290}
291
// Global embedding service instance.
// Set at most once: lazily by `get_embedding_service`, or explicitly by
// `init_embedding_service` — whichever runs first wins (OnceCell semantics).
static EMBEDDING_SERVICE: OnceCell<Arc<Mutex<EmbeddingService>>> = OnceCell::new();
294
295/// Get or initialize the global embedding service.
296pub async fn get_embedding_service() -> Arc<Mutex<EmbeddingService>> {
297    EMBEDDING_SERVICE
298        .get_or_init(|| Arc::new(Mutex::new(EmbeddingService::new())))
299        .clone()
300}
301
302/// Initialize the embedding service with custom configuration.
303pub fn init_embedding_service(model_name: Option<String>, dimension: Option<usize>) {
304    let service = if let (Some(name), Some(dim)) = (model_name, dimension) {
305        EmbeddingService::with_model(name, dim)
306    } else {
307        EmbeddingService::new()
308    };
309
310    let _ = EMBEDDING_SERVICE.set(Arc::new(Mutex::new(service)));
311}
312
#[cfg(test)]
mod tests {
    use super::*;

    // Embedding tests must pass in both build configurations, so each one
    // branches on backend availability instead of assuming a usable model.

    #[tokio::test]
    async fn test_embedding_dimension_or_unavailable() {
        let service = EmbeddingService::new();

        if service.is_available() {
            let embedding = service.embed("Hello world").await.unwrap();
            assert_eq!(embedding.len(), DEFAULT_EMBEDDING_DIMENSION);
        } else {
            let err = service.embed("Hello world").await.unwrap_err();
            assert!(err.to_string().contains("embeddings disabled"));
        }
    }

    #[tokio::test]
    async fn test_embed_batch_or_unavailable() {
        let service = EmbeddingService::new();
        let inputs: Vec<String> = ["First text", "Second text", "Third text"]
            .iter()
            .map(|s| s.to_string())
            .collect();

        let outcome = service.embed_batch(&inputs).await;
        if service.is_available() {
            let embeddings = outcome.unwrap();
            assert_eq!(embeddings.len(), 3);
            for emb in &embeddings {
                assert_eq!(emb.len(), DEFAULT_EMBEDDING_DIMENSION);
            }
        } else {
            assert!(outcome.is_err());
        }
    }

    #[test]
    fn test_cosine_similarity() {
        let unit_x = vec![1.0f32, 0.0, 0.0];
        let unit_x_again = vec![1.0f32, 0.0, 0.0];
        let unit_y = vec![0.0f32, 1.0, 0.0];

        let sim_same = EmbeddingService::cosine_similarity(&unit_x, &unit_x_again);
        let sim_orthogonal = EmbeddingService::cosine_similarity(&unit_x, &unit_y);

        assert!((sim_same - 1.0).abs() < 1e-6);
        assert!(sim_orthogonal.abs() < 1e-6);
    }
}