// tandem_memory/embeddings.rs
1// Embedding Service Module
2// Generates embeddings using local fastembed implementation.
3
4use crate::types::{
5    MemoryError, MemoryResult, DEFAULT_EMBEDDING_DIMENSION, DEFAULT_EMBEDDING_MODEL,
6};
7use once_cell::sync::OnceCell;
8#[cfg(feature = "local-embeddings")]
9use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
10#[cfg(feature = "local-embeddings")]
11use std::path::PathBuf;
12use std::sync::Arc;
13use tokio::sync::Mutex;
14
// Backend type alias: the real fastembed model when the `local-embeddings`
// feature is compiled in, otherwise a unit placeholder so the struct layout
// stays valid in feature-less builds.
#[cfg(feature = "local-embeddings")]
type EmbeddingBackend = TextEmbedding;
#[cfg(not(feature = "local-embeddings"))]
type EmbeddingBackend = ();

/// Embedding service for generating vector representations.
pub struct EmbeddingService {
    // Identifier of the configured embedding model.
    model_name: String,
    // Expected length of every produced embedding vector.
    dimension: usize,
    // Initialized backend; `None` when embeddings are unavailable.
    model: Option<EmbeddingBackend>,
    // Why embeddings are unavailable; `None` when the backend loaded.
    disabled_reason: Option<String>,
}
27
28impl EmbeddingService {
29    /// Create a new embedding service with default model.
30    pub fn new() -> Self {
31        Self::with_model(
32            DEFAULT_EMBEDDING_MODEL.to_string(),
33            DEFAULT_EMBEDDING_DIMENSION,
34        )
35    }
36
37    /// Create with custom model.
38    pub fn with_model(model_name: String, dimension: usize) -> Self {
39        let (model, disabled_reason) = Self::init_model(&model_name);
40
41        if let Some(reason) = &disabled_reason {
42            tracing::warn!(
43                target: "tandem.memory",
44                "Embeddings disabled: model={} reason={}",
45                model_name,
46                reason
47            );
48        } else {
49            tracing::info!(
50                target: "tandem.memory",
51                "Embeddings enabled: model={} dimension={}",
52                model_name,
53                dimension
54            );
55        }
56
57        Self {
58            model_name,
59            dimension,
60            model,
61            disabled_reason,
62        }
63    }
64
65    fn init_model(model_name: &str) -> (Option<EmbeddingBackend>, Option<String>) {
66        #[cfg(not(feature = "local-embeddings"))]
67        {
68            let _ = model_name;
69            return (
70                None,
71                Some("local embeddings are disabled at build time".to_string()),
72            );
73        }
74
75        #[cfg(feature = "local-embeddings")]
76        {
77        let Some(parsed_model) = Self::parse_model_id(model_name) else {
78            return (
79                None,
80                Some(format!(
81                    "unsupported embedding model id '{}'; supported: {}",
82                    model_name, DEFAULT_EMBEDDING_MODEL
83                )),
84            );
85        };
86
87        let cache_dir = resolve_embedding_cache_dir();
88        let options = InitOptions::new(parsed_model).with_cache_dir(cache_dir.clone());
89
90        tracing::info!(
91            target: "tandem.memory",
92            "Initializing embeddings with cache dir: {}",
93            cache_dir.display()
94        );
95
96        match TextEmbedding::try_new(options) {
97            Ok(model) => (Some(model), None),
98            Err(err) => (
99                None,
100                Some(format!(
101                    "failed to initialize embedding model '{}': {}",
102                    model_name, err
103                )),
104            ),
105        }
106        }
107    }
108
109    #[cfg(feature = "local-embeddings")]
110    fn parse_model_id(model_name: &str) -> Option<EmbeddingModel> {
111        match model_name.trim().to_ascii_lowercase().as_str() {
112            "all-minilm-l6-v2" | "all_minilm_l6_v2" => Some(EmbeddingModel::AllMiniLML6V2),
113            _ => None,
114        }
115    }
116
117    /// Get the embedding dimension.
118    pub fn dimension(&self) -> usize {
119        self.dimension
120    }
121
122    /// Get the model name.
123    pub fn model_name(&self) -> &str {
124        &self.model_name
125    }
126
127    /// Returns whether semantic embeddings are currently available.
128    pub fn is_available(&self) -> bool {
129        self.model.is_some()
130    }
131
132    /// Returns disabled reason if embeddings are unavailable.
133    pub fn disabled_reason(&self) -> Option<&str> {
134        self.disabled_reason.as_deref()
135    }
136
137    fn unavailable_error(&self) -> MemoryError {
138        let reason = self
139            .disabled_reason
140            .as_deref()
141            .unwrap_or("embedding backend unavailable");
142        MemoryError::Embedding(format!("embeddings disabled: {reason}"))
143    }
144
145    #[cfg(feature = "local-embeddings")]
146    fn ensure_dimension(&self, embedding: &[f32]) -> MemoryResult<()> {
147        if embedding.len() != self.dimension {
148            return Err(MemoryError::Embedding(format!(
149                "embedding dimension mismatch: expected {}, got {}",
150                self.dimension,
151                embedding.len()
152            )));
153        }
154        Ok(())
155    }
156
157    /// Generate embeddings for a single text.
158    pub async fn embed(&self, text: &str) -> MemoryResult<Vec<f32>> {
159        #[cfg(not(feature = "local-embeddings"))]
160        {
161            let _ = text;
162            return Err(self.unavailable_error());
163        }
164
165        #[cfg(feature = "local-embeddings")]
166        {
167        let Some(model) = self.model.as_ref() else {
168            return Err(self.unavailable_error());
169        };
170
171        let mut embeddings = model
172            .embed(vec![text.to_string()], None)
173            .map_err(|e| MemoryError::Embedding(e.to_string()))?;
174        let embedding = embeddings
175            .pop()
176            .ok_or_else(|| MemoryError::Embedding("no embedding generated".to_string()))?;
177        self.ensure_dimension(&embedding)?;
178        Ok(embedding)
179        }
180    }
181
182    /// Generate embeddings for multiple texts.
183    pub async fn embed_batch(&self, texts: &[String]) -> MemoryResult<Vec<Vec<f32>>> {
184        #[cfg(not(feature = "local-embeddings"))]
185        {
186            let _ = texts;
187            return Err(self.unavailable_error());
188        }
189
190        #[cfg(feature = "local-embeddings")]
191        {
192        let Some(model) = self.model.as_ref() else {
193            return Err(self.unavailable_error());
194        };
195
196        let embeddings = model
197            .embed(texts.to_vec(), None)
198            .map_err(|e| MemoryError::Embedding(e.to_string()))?;
199
200        for embedding in &embeddings {
201            self.ensure_dimension(embedding)?;
202        }
203
204        Ok(embeddings)
205        }
206    }
207
208    /// Calculate cosine similarity between two vectors.
209    pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
210        if a.len() != b.len() {
211            return 0.0;
212        }
213
214        let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
215        let magnitude_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
216        let magnitude_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
217
218        if magnitude_a == 0.0 || magnitude_b == 0.0 {
219            0.0
220        } else {
221            dot_product / (magnitude_a * magnitude_b)
222        }
223    }
224
225    /// Calculate Euclidean distance between two vectors.
226    pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
227        a.iter()
228            .zip(b.iter())
229            .map(|(x, y)| (x - y).powi(2))
230            .sum::<f32>()
231            .sqrt()
232    }
233}
234
// Pick the on-disk cache directory for fastembed model files.
//
// Precedence: the FASTEMBED_CACHE_DIR env var (returned even if it cannot be
// created), then a per-user data/cache directory, then the system temp dir.
// Directory-creation failures are logged and tolerated.
#[cfg(feature = "local-embeddings")]
fn resolve_embedding_cache_dir() -> PathBuf {
    // An explicit override always wins, even when creating it fails.
    if let Ok(explicit) = std::env::var("FASTEMBED_CACHE_DIR") {
        let explicit_path = PathBuf::from(explicit);
        if let Err(err) = std::fs::create_dir_all(&explicit_path) {
            tracing::warn!(
                target: "tandem.memory",
                "Failed to create FASTEMBED_CACHE_DIR {:?}: {}",
                explicit_path,
                err
            );
        }
        return explicit_path;
    }

    // Derive a per-user location, falling back to the temp dir as a last resort.
    let cache_dir = dirs::data_local_dir()
        .or_else(dirs::cache_dir)
        .unwrap_or_else(std::env::temp_dir)
        .join("tandem")
        .join("fastembed");

    if let Err(err) = std::fs::create_dir_all(&cache_dir) {
        tracing::warn!(
            target: "tandem.memory",
            "Failed to create embedding cache directory {:?}: {}",
            cache_dir,
            err
        );
    }

    cache_dir
}
266
267impl Default for EmbeddingService {
268    fn default() -> Self {
269        Self::new()
270    }
271}
272
273// Global embedding service instance.
274static EMBEDDING_SERVICE: OnceCell<Arc<Mutex<EmbeddingService>>> = OnceCell::new();
275
276/// Get or initialize the global embedding service.
277pub async fn get_embedding_service() -> Arc<Mutex<EmbeddingService>> {
278    EMBEDDING_SERVICE
279        .get_or_init(|| Arc::new(Mutex::new(EmbeddingService::new())))
280        .clone()
281}
282
283/// Initialize the embedding service with custom configuration.
284pub fn init_embedding_service(model_name: Option<String>, dimension: Option<usize>) {
285    let service = if let (Some(name), Some(dim)) = (model_name, dimension) {
286        EmbeddingService::with_model(name, dim)
287    } else {
288        EmbeddingService::new()
289    };
290
291    let _ = EMBEDDING_SERVICE.set(Arc::new(Mutex::new(service)));
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    // When the backend is available, a single embed must produce a vector of
    // the default dimension; otherwise it must fail with a disabled error.
    #[tokio::test]
    async fn test_embedding_dimension_or_unavailable() {
        let service = EmbeddingService::new();

        if service.is_available() {
            let embedding = service.embed("Hello world").await.unwrap();
            assert_eq!(embedding.len(), DEFAULT_EMBEDDING_DIMENSION);
        } else {
            let err = service.embed("Hello world").await.unwrap_err();
            assert!(err.to_string().contains("embeddings disabled"));
        }
    }

    // Batch embedding returns one vector per input, each of the default
    // dimension; when unavailable it errors instead.
    #[tokio::test]
    async fn test_embed_batch_or_unavailable() {
        let service = EmbeddingService::new();
        let texts: Vec<String> = ["First text", "Second text", "Third text"]
            .iter()
            .map(|s| s.to_string())
            .collect();

        let result = service.embed_batch(&texts).await;
        if service.is_available() {
            let embeddings = result.unwrap();
            assert_eq!(embeddings.len(), 3);
            assert!(embeddings
                .iter()
                .all(|emb| emb.len() == DEFAULT_EMBEDDING_DIMENSION));
        } else {
            assert!(result.is_err());
        }
    }

    // Identical unit vectors have similarity 1; orthogonal ones have 0.
    #[test]
    fn test_cosine_similarity() {
        let unit_x = vec![1.0f32, 0.0, 0.0];
        let unit_x2 = vec![1.0f32, 0.0, 0.0];
        let unit_y = vec![0.0f32, 1.0, 0.0];

        assert!((EmbeddingService::cosine_similarity(&unit_x, &unit_x2) - 1.0).abs() < 1e-6);
        assert!(EmbeddingService::cosine_similarity(&unit_x, &unit_y).abs() < 1e-6);
    }
}
346}