//! Embedding Service Module (tandem_memory/embeddings.rs)
//!
//! Generates embeddings using local fastembed implementation.
use crate::types::{
    MemoryError, MemoryResult, DEFAULT_EMBEDDING_DIMENSION, DEFAULT_EMBEDDING_MODEL,
};
#[cfg(feature = "local-embeddings")]
use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
use once_cell::sync::OnceCell;
use sha2::{Digest, Sha256};
#[cfg(feature = "local-embeddings")]
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::Mutex;

// Backend alias: the real fastembed model when `local-embeddings` is
// compiled in, otherwise a unit placeholder so the struct definition
// stays feature-independent.
#[cfg(feature = "local-embeddings")]
type EmbeddingBackend = TextEmbedding;
#[cfg(not(feature = "local-embeddings"))]
type EmbeddingBackend = ();

/// Embedding service for generating vector representations.
pub struct EmbeddingService {
    // Identifier of the embedding model (e.g. "all-minilm-l6-v2").
    model_name: String,
    // Expected output vector length; generated embeddings are validated
    // against this value.
    dimension: usize,
    // Loaded backend model; `None` when embeddings are unavailable.
    model: Option<EmbeddingBackend>,
    // Human-readable reason embeddings are disabled, when they are.
    disabled_reason: Option<String>,
    // When true, produce SHA-256-derived vectors instead of running a
    // model (used by tests and offline callers).
    deterministic: bool,
}

30impl EmbeddingService {
31    /// Create a new embedding service with default model.
32    pub fn new() -> Self {
33        Self::with_model(
34            DEFAULT_EMBEDDING_MODEL.to_string(),
35            DEFAULT_EMBEDDING_DIMENSION,
36        )
37    }
38
39    /// Create with custom model.
40    pub fn with_model(model_name: String, dimension: usize) -> Self {
41        let (model, disabled_reason) = Self::init_model(&model_name);
42
43        if let Some(reason) = &disabled_reason {
44            tracing::warn!(
45                target: "tandem.memory",
46                "Embeddings disabled: model={} reason={}",
47                model_name,
48                reason
49            );
50        } else {
51            tracing::info!(
52                target: "tandem.memory",
53                "Embeddings enabled: model={} dimension={}",
54                model_name,
55                dimension
56            );
57        }
58
59        Self {
60            model_name,
61            dimension,
62            model,
63            disabled_reason,
64            deterministic: false,
65        }
66    }
67
68    /// Create a lightweight deterministic embedding service for tests
69    /// and offline callers that need vector-shaped data without loading
70    /// the local ONNX model.
71    pub fn deterministic_for_tests(dimension: usize) -> Self {
72        Self {
73            model_name: "deterministic-test-embedding".to_string(),
74            dimension,
75            model: None,
76            disabled_reason: None,
77            deterministic: true,
78        }
79    }
80
81    fn init_model(model_name: &str) -> (Option<EmbeddingBackend>, Option<String>) {
82        #[cfg(not(feature = "local-embeddings"))]
83        {
84            let _ = model_name;
85            (
86                None,
87                Some("local embeddings are disabled at build time".to_string()),
88            )
89        }
90
91        #[cfg(feature = "local-embeddings")]
92        {
93            if let Some(reason) = embeddings_runtime_disabled_reason() {
94                return (None, Some(reason));
95            }
96
97            let Some(parsed_model) = Self::parse_model_id(model_name) else {
98                return (
99                    None,
100                    Some(format!(
101                        "unsupported embedding model id '{}'; supported: {}",
102                        model_name, DEFAULT_EMBEDDING_MODEL
103                    )),
104                );
105            };
106
107            let cache_dir = resolve_embedding_cache_dir();
108            let options = InitOptions::new(parsed_model).with_cache_dir(cache_dir.clone());
109
110            tracing::info!(
111                target: "tandem.memory",
112                "Initializing embeddings with cache dir: {}",
113                cache_dir.display()
114            );
115
116            match TextEmbedding::try_new(options) {
117                Ok(model) => (Some(model), None),
118                Err(err) => (
119                    None,
120                    Some(format!(
121                        "failed to initialize embedding model '{}': {}",
122                        model_name, err
123                    )),
124                ),
125            }
126        }
127    }
128
129    #[cfg(feature = "local-embeddings")]
130    fn parse_model_id(model_name: &str) -> Option<EmbeddingModel> {
131        match model_name.trim().to_ascii_lowercase().as_str() {
132            "all-minilm-l6-v2" | "all_minilm_l6_v2" => Some(EmbeddingModel::AllMiniLML6V2),
133            _ => None,
134        }
135    }
136
137    /// Get the embedding dimension.
138    pub fn dimension(&self) -> usize {
139        self.dimension
140    }
141
142    /// Get the model name.
143    pub fn model_name(&self) -> &str {
144        &self.model_name
145    }
146
147    /// Returns whether semantic embeddings are currently available.
148    pub fn is_available(&self) -> bool {
149        self.model.is_some()
150    }
151
152    /// Returns disabled reason if embeddings are unavailable.
153    pub fn disabled_reason(&self) -> Option<&str> {
154        self.disabled_reason.as_deref()
155    }
156
157    fn unavailable_error(&self) -> MemoryError {
158        let reason = self
159            .disabled_reason
160            .as_deref()
161            .unwrap_or("embedding backend unavailable");
162        MemoryError::Embedding(format!("embeddings disabled: {reason}"))
163    }
164
165    fn deterministic_embedding(&self, text: &str) -> Vec<f32> {
166        let mut values = Vec::with_capacity(self.dimension);
167        let mut seed = Sha256::digest(text.as_bytes()).to_vec();
168        while values.len() < self.dimension {
169            for byte in &seed {
170                let value = (*byte as f32 / 127.5) - 1.0;
171                values.push(value);
172                if values.len() == self.dimension {
173                    break;
174                }
175            }
176            seed = Sha256::digest(&seed).to_vec();
177        }
178        values
179    }
180
181    #[cfg(feature = "local-embeddings")]
182    fn ensure_dimension(&self, embedding: &[f32]) -> MemoryResult<()> {
183        if embedding.len() != self.dimension {
184            return Err(MemoryError::Embedding(format!(
185                "embedding dimension mismatch: expected {}, got {}",
186                self.dimension,
187                embedding.len()
188            )));
189        }
190        Ok(())
191    }
192
193    /// Generate embeddings for a single text.
194    pub async fn embed(&self, text: &str) -> MemoryResult<Vec<f32>> {
195        if self.deterministic {
196            return Ok(self.deterministic_embedding(text));
197        }
198
199        #[cfg(not(feature = "local-embeddings"))]
200        {
201            let _ = text;
202            Err(self.unavailable_error())
203        }
204
205        #[cfg(feature = "local-embeddings")]
206        {
207            let Some(model) = self.model.as_ref() else {
208                return Err(self.unavailable_error());
209            };
210
211            let mut embeddings = model
212                .embed(vec![text.to_string()], None)
213                .map_err(|e| MemoryError::Embedding(e.to_string()))?;
214            let embedding = embeddings
215                .pop()
216                .ok_or_else(|| MemoryError::Embedding("no embedding generated".to_string()))?;
217            self.ensure_dimension(&embedding)?;
218            Ok(embedding)
219        }
220    }
221
222    /// Generate embeddings for multiple texts.
223    pub async fn embed_batch(&self, texts: &[String]) -> MemoryResult<Vec<Vec<f32>>> {
224        if self.deterministic {
225            return Ok(texts
226                .iter()
227                .map(|text| self.deterministic_embedding(text))
228                .collect());
229        }
230
231        #[cfg(not(feature = "local-embeddings"))]
232        {
233            let _ = texts;
234            Err(self.unavailable_error())
235        }
236
237        #[cfg(feature = "local-embeddings")]
238        {
239            let Some(model) = self.model.as_ref() else {
240                return Err(self.unavailable_error());
241            };
242
243            let embeddings = model
244                .embed(texts.to_vec(), None)
245                .map_err(|e| MemoryError::Embedding(e.to_string()))?;
246
247            for embedding in &embeddings {
248                self.ensure_dimension(embedding)?;
249            }
250
251            Ok(embeddings)
252        }
253    }
254
255    /// Calculate cosine similarity between two vectors.
256    pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
257        if a.len() != b.len() {
258            return 0.0;
259        }
260
261        let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
262        let magnitude_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
263        let magnitude_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
264
265        if magnitude_a == 0.0 || magnitude_b == 0.0 {
266            0.0
267        } else {
268            dot_product / (magnitude_a * magnitude_b)
269        }
270    }
271
272    /// Calculate Euclidean distance between two vectors.
273    pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
274        a.iter()
275            .zip(b.iter())
276            .map(|(x, y)| (x - y).powi(2))
277            .sum::<f32>()
278            .sqrt()
279    }
280}
281
// Check the TANDEM_DISABLE_EMBEDDINGS kill switch.
//
// Returns a disabled-reason string when the variable is set to a truthy
// value ("1", "true", "yes", "on" — case-insensitive, whitespace-trimmed);
// otherwise returns None.
#[cfg(feature = "local-embeddings")]
fn embeddings_runtime_disabled_reason() -> Option<String> {
    match std::env::var("TANDEM_DISABLE_EMBEDDINGS") {
        Ok(raw) => {
            let flag = raw.trim().to_ascii_lowercase();
            if matches!(flag.as_str(), "1" | "true" | "yes" | "on") {
                Some("disabled by TANDEM_DISABLE_EMBEDDINGS".to_string())
            } else {
                None
            }
        }
        Err(_) => None,
    }
}

// Resolve the directory where fastembed model files are cached.
//
// Precedence: FASTEMBED_CACHE_DIR if set, otherwise a "tandem/fastembed"
// subdirectory under the platform local-data dir (falling back to the
// cache dir, then the temp dir). Directory creation failures are logged
// but not fatal: the chosen path is returned regardless.
#[cfg(feature = "local-embeddings")]
fn resolve_embedding_cache_dir() -> PathBuf {
    // An explicit override always wins.
    if let Ok(explicit) = std::env::var("FASTEMBED_CACHE_DIR") {
        let override_path = PathBuf::from(explicit);
        if let Err(err) = std::fs::create_dir_all(&override_path) {
            tracing::warn!(
                target: "tandem.memory",
                "Failed to create FASTEMBED_CACHE_DIR {:?}: {}",
                override_path,
                err
            );
        }
        return override_path;
    }

    let default_dir = dirs::data_local_dir()
        .or_else(dirs::cache_dir)
        .unwrap_or_else(std::env::temp_dir)
        .join("tandem")
        .join("fastembed");

    if let Err(err) = std::fs::create_dir_all(&default_dir) {
        tracing::warn!(
            target: "tandem.memory",
            "Failed to create embedding cache directory {:?}: {}",
            default_dir,
            err
        );
    }

    default_dir
}

/// `Default` delegates to [`EmbeddingService::new`] (default model and
/// dimension).
impl Default for EmbeddingService {
    fn default() -> Self {
        Self::new()
    }
}

// Global embedding service instance, created lazily on first access.
static EMBEDDING_SERVICE: OnceCell<Arc<Mutex<EmbeddingService>>> = OnceCell::new();

/// Get or initialize the global embedding service.
///
/// Initializes with default settings on first call unless
/// `init_embedding_service` ran earlier. The function is `async` only for
/// call-site convenience; it performs no awaiting itself.
pub async fn get_embedding_service() -> Arc<Mutex<EmbeddingService>> {
    EMBEDDING_SERVICE
        .get_or_init(|| Arc::new(Mutex::new(EmbeddingService::new())))
        .clone()
}

345/// Initialize the embedding service with custom configuration.
346pub fn init_embedding_service(model_name: Option<String>, dimension: Option<usize>) {
347    let service = if let (Some(name), Some(dim)) = (model_name, dimension) {
348        EmbeddingService::with_model(name, dim)
349    } else {
350        EmbeddingService::new()
351    };
352
353    let _ = EMBEDDING_SERVICE.set(Arc::new(Mutex::new(service)));
354}
355
#[cfg(test)]
mod tests {
    use super::*;

    // embed() must either fail with a "disabled" error (no backend) or
    // return a vector of the default dimension.
    #[tokio::test]
    async fn test_embedding_dimension_or_unavailable() {
        let service = EmbeddingService::new();

        match service.embed("Hello world").await {
            Err(err) => {
                assert!(!service.is_available());
                assert!(err.to_string().contains("embeddings disabled"));
            }
            Ok(embedding) => {
                assert!(service.is_available());
                assert_eq!(embedding.len(), DEFAULT_EMBEDDING_DIMENSION);
            }
        }
    }

    // embed_batch() mirrors embed(): error when unavailable, otherwise one
    // correctly-sized vector per input text.
    #[tokio::test]
    async fn test_embed_batch_or_unavailable() {
        let service = EmbeddingService::new();
        let texts: Vec<String> = ["First text", "Second text", "Third text"]
            .iter()
            .map(|s| s.to_string())
            .collect();

        let result = service.embed_batch(&texts).await;
        if !service.is_available() {
            assert!(result.is_err());
            return;
        }

        let embeddings = result.unwrap();
        assert_eq!(embeddings.len(), 3);
        assert!(embeddings
            .iter()
            .all(|emb| emb.len() == DEFAULT_EMBEDDING_DIMENSION));
    }

    // Identical unit vectors → similarity 1; orthogonal unit vectors → 0.
    #[test]
    fn test_cosine_similarity() {
        let unit_x = vec![1.0f32, 0.0, 0.0];
        let unit_x_copy = vec![1.0f32, 0.0, 0.0];
        let unit_y = vec![0.0f32, 1.0, 0.0];

        let sim_same = EmbeddingService::cosine_similarity(&unit_x, &unit_x_copy);
        let sim_orthogonal = EmbeddingService::cosine_similarity(&unit_x, &unit_y);

        assert!((sim_same - 1.0).abs() < 1e-6);
        assert!(sim_orthogonal.abs() < 1e-6);
    }
}