// tandem_memory/embeddings.rs
//
// Embedding Service Module
// Generates embeddings using local fastembed implementation.
use std::path::PathBuf;
use std::sync::Arc;

use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
use once_cell::sync::OnceCell;
use tokio::sync::Mutex;

use crate::types::{
    MemoryError, MemoryResult, DEFAULT_EMBEDDING_DIMENSION, DEFAULT_EMBEDDING_MODEL,
};
/// Embedding service for generating vector representations.
///
/// Initialization failures are non-fatal: the service is still
/// constructed, reports itself as unavailable, and carries a
/// human-readable reason instead of an error.
pub struct EmbeddingService {
    // Identifier of the configured model (e.g. the default model id).
    model_name: String,
    // Expected length of every embedding vector this service produces.
    dimension: usize,
    // Loaded fastembed model; `None` when initialization failed.
    model: Option<TextEmbedding>,
    // Why embeddings are disabled, if they are; `None` when available.
    disabled_reason: Option<String>,
}
21impl EmbeddingService {
22    /// Create a new embedding service with default model.
23    pub fn new() -> Self {
24        Self::with_model(
25            DEFAULT_EMBEDDING_MODEL.to_string(),
26            DEFAULT_EMBEDDING_DIMENSION,
27        )
28    }
29
30    /// Create with custom model.
31    pub fn with_model(model_name: String, dimension: usize) -> Self {
32        let (model, disabled_reason) = Self::init_model(&model_name);
33
34        if let Some(reason) = &disabled_reason {
35            tracing::warn!(
36                target: "tandem.memory",
37                "Embeddings disabled: model={} reason={}",
38                model_name,
39                reason
40            );
41        } else {
42            tracing::info!(
43                target: "tandem.memory",
44                "Embeddings enabled: model={} dimension={}",
45                model_name,
46                dimension
47            );
48        }
49
50        Self {
51            model_name,
52            dimension,
53            model,
54            disabled_reason,
55        }
56    }
57
58    fn init_model(model_name: &str) -> (Option<TextEmbedding>, Option<String>) {
59        let Some(parsed_model) = Self::parse_model_id(model_name) else {
60            return (
61                None,
62                Some(format!(
63                    "unsupported embedding model id '{}'; supported: {}",
64                    model_name, DEFAULT_EMBEDDING_MODEL
65                )),
66            );
67        };
68
69        let cache_dir = resolve_embedding_cache_dir();
70        let options = InitOptions::new(parsed_model).with_cache_dir(cache_dir.clone());
71
72        tracing::info!(
73            target: "tandem.memory",
74            "Initializing embeddings with cache dir: {}",
75            cache_dir.display()
76        );
77
78        match TextEmbedding::try_new(options) {
79            Ok(model) => (Some(model), None),
80            Err(err) => (
81                None,
82                Some(format!(
83                    "failed to initialize embedding model '{}': {}",
84                    model_name, err
85                )),
86            ),
87        }
88    }
89
90    fn parse_model_id(model_name: &str) -> Option<EmbeddingModel> {
91        match model_name.trim().to_ascii_lowercase().as_str() {
92            "all-minilm-l6-v2" | "all_minilm_l6_v2" => Some(EmbeddingModel::AllMiniLML6V2),
93            _ => None,
94        }
95    }
96
97    /// Get the embedding dimension.
98    pub fn dimension(&self) -> usize {
99        self.dimension
100    }
101
102    /// Get the model name.
103    pub fn model_name(&self) -> &str {
104        &self.model_name
105    }
106
107    /// Returns whether semantic embeddings are currently available.
108    pub fn is_available(&self) -> bool {
109        self.model.is_some()
110    }
111
112    /// Returns disabled reason if embeddings are unavailable.
113    pub fn disabled_reason(&self) -> Option<&str> {
114        self.disabled_reason.as_deref()
115    }
116
117    fn unavailable_error(&self) -> MemoryError {
118        let reason = self
119            .disabled_reason
120            .as_deref()
121            .unwrap_or("embedding backend unavailable");
122        MemoryError::Embedding(format!("embeddings disabled: {reason}"))
123    }
124
125    fn ensure_dimension(&self, embedding: &[f32]) -> MemoryResult<()> {
126        if embedding.len() != self.dimension {
127            return Err(MemoryError::Embedding(format!(
128                "embedding dimension mismatch: expected {}, got {}",
129                self.dimension,
130                embedding.len()
131            )));
132        }
133        Ok(())
134    }
135
136    /// Generate embeddings for a single text.
137    pub async fn embed(&self, text: &str) -> MemoryResult<Vec<f32>> {
138        let Some(model) = self.model.as_ref() else {
139            return Err(self.unavailable_error());
140        };
141
142        let mut embeddings = model
143            .embed(vec![text.to_string()], None)
144            .map_err(|e| MemoryError::Embedding(e.to_string()))?;
145        let embedding = embeddings
146            .pop()
147            .ok_or_else(|| MemoryError::Embedding("no embedding generated".to_string()))?;
148        self.ensure_dimension(&embedding)?;
149        Ok(embedding)
150    }
151
152    /// Generate embeddings for multiple texts.
153    pub async fn embed_batch(&self, texts: &[String]) -> MemoryResult<Vec<Vec<f32>>> {
154        let Some(model) = self.model.as_ref() else {
155            return Err(self.unavailable_error());
156        };
157
158        let embeddings = model
159            .embed(texts.to_vec(), None)
160            .map_err(|e| MemoryError::Embedding(e.to_string()))?;
161
162        for embedding in &embeddings {
163            self.ensure_dimension(embedding)?;
164        }
165
166        Ok(embeddings)
167    }
168
169    /// Calculate cosine similarity between two vectors.
170    pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
171        if a.len() != b.len() {
172            return 0.0;
173        }
174
175        let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
176        let magnitude_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
177        let magnitude_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
178
179        if magnitude_a == 0.0 || magnitude_b == 0.0 {
180            0.0
181        } else {
182            dot_product / (magnitude_a * magnitude_b)
183        }
184    }
185
186    /// Calculate Euclidean distance between two vectors.
187    pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
188        a.iter()
189            .zip(b.iter())
190            .map(|(x, y)| (x - y).powi(2))
191            .sum::<f32>()
192            .sqrt()
193    }
194}
195
196fn resolve_embedding_cache_dir() -> PathBuf {
197    if let Ok(explicit) = std::env::var("FASTEMBED_CACHE_DIR") {
198        let explicit_path = PathBuf::from(explicit);
199        if let Err(err) = std::fs::create_dir_all(&explicit_path) {
200            tracing::warn!(
201                target: "tandem.memory",
202                "Failed to create FASTEMBED_CACHE_DIR {:?}: {}",
203                explicit_path,
204                err
205            );
206        }
207        return explicit_path;
208    }
209
210    let base = dirs::data_local_dir()
211        .or_else(dirs::cache_dir)
212        .unwrap_or_else(std::env::temp_dir);
213    let cache_dir = base.join("tandem").join("fastembed");
214
215    if let Err(err) = std::fs::create_dir_all(&cache_dir) {
216        tracing::warn!(
217            target: "tandem.memory",
218            "Failed to create embedding cache directory {:?}: {}",
219            cache_dir,
220            err
221        );
222    }
223
224    cache_dir
225}
226
227impl Default for EmbeddingService {
228    fn default() -> Self {
229        Self::new()
230    }
231}
232
233// Global embedding service instance.
234static EMBEDDING_SERVICE: OnceCell<Arc<Mutex<EmbeddingService>>> = OnceCell::new();
235
236/// Get or initialize the global embedding service.
237pub async fn get_embedding_service() -> Arc<Mutex<EmbeddingService>> {
238    EMBEDDING_SERVICE
239        .get_or_init(|| Arc::new(Mutex::new(EmbeddingService::new())))
240        .clone()
241}
242
243/// Initialize the embedding service with custom configuration.
244pub fn init_embedding_service(model_name: Option<String>, dimension: Option<usize>) {
245    let service = if let (Some(name), Some(dim)) = (model_name, dimension) {
246        EmbeddingService::with_model(name, dim)
247    } else {
248        EmbeddingService::new()
249    };
250
251    let _ = EMBEDDING_SERVICE.set(Arc::new(Mutex::new(service)));
252}
253
254#[cfg(test)]
255mod tests {
256    use super::*;
257
258    #[tokio::test]
259    async fn test_embedding_dimension_or_unavailable() {
260        let service = EmbeddingService::new();
261
262        if !service.is_available() {
263            let err = service.embed("Hello world").await.unwrap_err();
264            assert!(err.to_string().contains("embeddings disabled"));
265            return;
266        }
267
268        let embedding = service.embed("Hello world").await.unwrap();
269        assert_eq!(embedding.len(), DEFAULT_EMBEDDING_DIMENSION);
270    }
271
272    #[tokio::test]
273    async fn test_embed_batch_or_unavailable() {
274        let service = EmbeddingService::new();
275        let texts = vec![
276            "First text".to_string(),
277            "Second text".to_string(),
278            "Third text".to_string(),
279        ];
280
281        let result = service.embed_batch(&texts).await;
282        if !service.is_available() {
283            assert!(result.is_err());
284            return;
285        }
286
287        let embeddings = result.unwrap();
288        assert_eq!(embeddings.len(), 3);
289        for emb in &embeddings {
290            assert_eq!(emb.len(), DEFAULT_EMBEDDING_DIMENSION);
291        }
292    }
293
294    #[test]
295    fn test_cosine_similarity() {
296        let a = vec![1.0f32, 0.0, 0.0];
297        let b = vec![1.0f32, 0.0, 0.0];
298        let c = vec![0.0f32, 1.0, 0.0];
299
300        let sim_same = EmbeddingService::cosine_similarity(&a, &b);
301        let sim_orthogonal = EmbeddingService::cosine_similarity(&a, &c);
302
303        assert!((sim_same - 1.0).abs() < 1e-6);
304        assert!(sim_orthogonal.abs() < 1e-6);
305    }
306}