// tandem_memory/embeddings.rs
use crate::types::{
5 MemoryError, MemoryResult, DEFAULT_EMBEDDING_DIMENSION, DEFAULT_EMBEDDING_MODEL,
6};
7use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
8use once_cell::sync::OnceCell;
9use std::path::PathBuf;
10use std::sync::Arc;
11use tokio::sync::Mutex;
12
/// Lazily-initialized text-embedding backend.
///
/// Wraps an optional `fastembed` model: when initialization fails the
/// service stays constructible and reports itself as disabled (see
/// `disabled_reason`) instead of panicking.
pub struct EmbeddingService {
    // User-facing identifier of the configured embedding model.
    model_name: String,
    // Expected length of every produced embedding vector.
    dimension: usize,
    // `None` when the backend could not be initialized.
    model: Option<TextEmbedding>,
    // Human-readable explanation for a disabled backend, if any.
    disabled_reason: Option<String>,
}
20
21impl EmbeddingService {
    /// Creates a service using the crate-default model and dimension.
    pub fn new() -> Self {
        Self::with_model(
            DEFAULT_EMBEDDING_MODEL.to_string(),
            DEFAULT_EMBEDDING_DIMENSION,
        )
    }
29
30 pub fn with_model(model_name: String, dimension: usize) -> Self {
32 let (model, disabled_reason) = Self::init_model(&model_name);
33
34 if let Some(reason) = &disabled_reason {
35 tracing::warn!(
36 target: "tandem.memory",
37 "Embeddings disabled: model={} reason={}",
38 model_name,
39 reason
40 );
41 } else {
42 tracing::info!(
43 target: "tandem.memory",
44 "Embeddings enabled: model={} dimension={}",
45 model_name,
46 dimension
47 );
48 }
49
50 Self {
51 model_name,
52 dimension,
53 model,
54 disabled_reason,
55 }
56 }
57
58 fn init_model(model_name: &str) -> (Option<TextEmbedding>, Option<String>) {
59 let Some(parsed_model) = Self::parse_model_id(model_name) else {
60 return (
61 None,
62 Some(format!(
63 "unsupported embedding model id '{}'; supported: {}",
64 model_name, DEFAULT_EMBEDDING_MODEL
65 )),
66 );
67 };
68
69 let cache_dir = resolve_embedding_cache_dir();
70 let options = InitOptions::new(parsed_model).with_cache_dir(cache_dir.clone());
71
72 tracing::info!(
73 target: "tandem.memory",
74 "Initializing embeddings with cache dir: {}",
75 cache_dir.display()
76 );
77
78 match TextEmbedding::try_new(options) {
79 Ok(model) => (Some(model), None),
80 Err(err) => (
81 None,
82 Some(format!(
83 "failed to initialize embedding model '{}': {}",
84 model_name, err
85 )),
86 ),
87 }
88 }
89
90 fn parse_model_id(model_name: &str) -> Option<EmbeddingModel> {
91 match model_name.trim().to_ascii_lowercase().as_str() {
92 "all-minilm-l6-v2" | "all_minilm_l6_v2" => Some(EmbeddingModel::AllMiniLML6V2),
93 _ => None,
94 }
95 }
96
    /// Expected length of produced embedding vectors.
    pub fn dimension(&self) -> usize {
        self.dimension
    }
101
    /// Identifier of the configured embedding model.
    pub fn model_name(&self) -> &str {
        &self.model_name
    }
106
    /// Whether the backend initialized successfully and can embed text.
    pub fn is_available(&self) -> bool {
        self.model.is_some()
    }
111
    /// Why embeddings are disabled, or `None` when they are available.
    pub fn disabled_reason(&self) -> Option<&str> {
        self.disabled_reason.as_deref()
    }
116
117 fn unavailable_error(&self) -> MemoryError {
118 let reason = self
119 .disabled_reason
120 .as_deref()
121 .unwrap_or("embedding backend unavailable");
122 MemoryError::Embedding(format!("embeddings disabled: {reason}"))
123 }
124
125 fn ensure_dimension(&self, embedding: &[f32]) -> MemoryResult<()> {
126 if embedding.len() != self.dimension {
127 return Err(MemoryError::Embedding(format!(
128 "embedding dimension mismatch: expected {}, got {}",
129 self.dimension,
130 embedding.len()
131 )));
132 }
133 Ok(())
134 }
135
    /// Embeds a single `text` into a vector of the configured dimension.
    ///
    /// # Errors
    ///
    /// Returns `MemoryError::Embedding` when the backend is disabled, when
    /// inference fails, when no vector comes back, or when the vector's
    /// length differs from the configured dimension.
    pub async fn embed(&self, text: &str) -> MemoryResult<Vec<f32>> {
        // Disabled backend: surface the recorded reason instead of panicking.
        let Some(model) = self.model.as_ref() else {
            return Err(self.unavailable_error());
        };

        // NOTE(review): `model.embed` has no `.await`, so inference runs
        // synchronously on the executor thread inside this async fn.
        // Consider `spawn_blocking` if this becomes a hot path — confirm
        // caller expectations before changing it.
        let mut embeddings = model
            .embed(vec![text.to_string()], None)
            .map_err(|e| MemoryError::Embedding(e.to_string()))?;
        let embedding = embeddings
            .pop()
            .ok_or_else(|| MemoryError::Embedding("no embedding generated".to_string()))?;
        // Guard against backend/model configuration drift.
        self.ensure_dimension(&embedding)?;
        Ok(embedding)
    }
151
152 pub async fn embed_batch(&self, texts: &[String]) -> MemoryResult<Vec<Vec<f32>>> {
154 let Some(model) = self.model.as_ref() else {
155 return Err(self.unavailable_error());
156 };
157
158 let embeddings = model
159 .embed(texts.to_vec(), None)
160 .map_err(|e| MemoryError::Embedding(e.to_string()))?;
161
162 for embedding in &embeddings {
163 self.ensure_dimension(embedding)?;
164 }
165
166 Ok(embeddings)
167 }
168
169 pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
171 if a.len() != b.len() {
172 return 0.0;
173 }
174
175 let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
176 let magnitude_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
177 let magnitude_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
178
179 if magnitude_a == 0.0 || magnitude_b == 0.0 {
180 0.0
181 } else {
182 dot_product / (magnitude_a * magnitude_b)
183 }
184 }
185
186 pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
188 a.iter()
189 .zip(b.iter())
190 .map(|(x, y)| (x - y).powi(2))
191 .sum::<f32>()
192 .sqrt()
193 }
194}
195
196fn resolve_embedding_cache_dir() -> PathBuf {
197 if let Ok(explicit) = std::env::var("FASTEMBED_CACHE_DIR") {
198 let explicit_path = PathBuf::from(explicit);
199 if let Err(err) = std::fs::create_dir_all(&explicit_path) {
200 tracing::warn!(
201 target: "tandem.memory",
202 "Failed to create FASTEMBED_CACHE_DIR {:?}: {}",
203 explicit_path,
204 err
205 );
206 }
207 return explicit_path;
208 }
209
210 let base = dirs::data_local_dir()
211 .or_else(dirs::cache_dir)
212 .unwrap_or_else(std::env::temp_dir);
213 let cache_dir = base.join("tandem").join("fastembed");
214
215 if let Err(err) = std::fs::create_dir_all(&cache_dir) {
216 tracing::warn!(
217 target: "tandem.memory",
218 "Failed to create embedding cache directory {:?}: {}",
219 cache_dir,
220 err
221 );
222 }
223
224 cache_dir
225}
226
/// `Default` delegates to [`EmbeddingService::new`], i.e. the crate-default
/// model and dimension.
impl Default for EmbeddingService {
    fn default() -> Self {
        Self::new()
    }
}
232
// Process-wide singleton; filled lazily on first access unless
// `init_embedding_service` installs a configured instance first.
static EMBEDDING_SERVICE: OnceCell<Arc<Mutex<EmbeddingService>>> = OnceCell::new();

/// Returns the shared embedding service, creating a default-configured
/// instance on first use.
///
/// NOTE(review): the body contains no `.await` — the fn is presumably kept
/// `async` so existing `get_embedding_service().await` call sites stay
/// valid; confirm before making it synchronous.
pub async fn get_embedding_service() -> Arc<Mutex<EmbeddingService>> {
    EMBEDDING_SERVICE
        .get_or_init(|| Arc::new(Mutex::new(EmbeddingService::new())))
        .clone()
}
242
243pub fn init_embedding_service(model_name: Option<String>, dimension: Option<usize>) {
245 let service = if let (Some(name), Some(dim)) = (model_name, dimension) {
246 EmbeddingService::with_model(name, dim)
247 } else {
248 EmbeddingService::new()
249 };
250
251 let _ = EMBEDDING_SERVICE.set(Arc::new(Mutex::new(service)));
252}
253
254#[cfg(test)]
255mod tests {
256 use super::*;
257
258 #[tokio::test]
259 async fn test_embedding_dimension_or_unavailable() {
260 let service = EmbeddingService::new();
261
262 if !service.is_available() {
263 let err = service.embed("Hello world").await.unwrap_err();
264 assert!(err.to_string().contains("embeddings disabled"));
265 return;
266 }
267
268 let embedding = service.embed("Hello world").await.unwrap();
269 assert_eq!(embedding.len(), DEFAULT_EMBEDDING_DIMENSION);
270 }
271
272 #[tokio::test]
273 async fn test_embed_batch_or_unavailable() {
274 let service = EmbeddingService::new();
275 let texts = vec![
276 "First text".to_string(),
277 "Second text".to_string(),
278 "Third text".to_string(),
279 ];
280
281 let result = service.embed_batch(&texts).await;
282 if !service.is_available() {
283 assert!(result.is_err());
284 return;
285 }
286
287 let embeddings = result.unwrap();
288 assert_eq!(embeddings.len(), 3);
289 for emb in &embeddings {
290 assert_eq!(emb.len(), DEFAULT_EMBEDDING_DIMENSION);
291 }
292 }
293
294 #[test]
295 fn test_cosine_similarity() {
296 let a = vec![1.0f32, 0.0, 0.0];
297 let b = vec![1.0f32, 0.0, 0.0];
298 let c = vec![0.0f32, 1.0, 0.0];
299
300 let sim_same = EmbeddingService::cosine_similarity(&a, &b);
301 let sim_orthogonal = EmbeddingService::cosine_similarity(&a, &c);
302
303 assert!((sim_same - 1.0).abs() < 1e-6);
304 assert!(sim_orthogonal.abs() < 1e-6);
305 }
306}