starpod_memory/
embedder.rs1use starpod_core::Result;
11#[cfg(feature = "embeddings")]
12use starpod_core::StarpodError;
13
14#[async_trait::async_trait]
19pub trait Embedder: Send + Sync {
20 async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>>;
25
26 fn dimensions(&self) -> usize;
28}
29
/// Embedder backed by a `fastembed::TextEmbedding` model.
///
/// Only compiled when the `embeddings` feature is enabled.
#[cfg(feature = "embeddings")]
pub struct LocalEmbedder {
    /// Mutex-protected model slot: `None` means no model has been built yet,
    /// `Some` holds the ready-to-use model.
    model: std::sync::Mutex<Option<fastembed::TextEmbedding>>,
}
41
#[cfg(feature = "embeddings")]
impl LocalEmbedder {
    /// Creates an embedder whose model slot starts out empty.
    ///
    /// No model is built here; that happens inside `get_or_init`.
    pub fn new() -> Self {
        let model = std::sync::Mutex::new(None);
        Self { model }
    }

    /// Locks the model slot, filling it if it is still empty, and returns the
    /// guard.
    ///
    /// On a successful return the guarded `Option` is always `Some`. The mutex
    /// stays locked for as long as the caller keeps the returned guard alive.
    ///
    /// # Errors
    ///
    /// Returns `StarpodError::Agent` when the mutex is poisoned or when
    /// `fastembed::TextEmbedding::try_new` fails.
    fn get_or_init(&self) -> Result<std::sync::MutexGuard<'_, Option<fastembed::TextEmbedding>>> {
        let mut slot = self
            .model
            .lock()
            .map_err(|poison| StarpodError::Agent(format!("Embedder lock poisoned: {}", poison)))?;

        // Fast path: an earlier call already populated the slot.
        if slot.is_some() {
            return Ok(slot);
        }

        // First use: build the `BGESmallENV15` model with the download
        // progress display switched off.
        let options = fastembed::InitOptions::new(fastembed::EmbeddingModel::BGESmallENV15)
            .with_show_download_progress(false);
        let loaded = fastembed::TextEmbedding::try_new(options).map_err(|err| {
            StarpodError::Agent(format!("Failed to init embedding model: {}", err))
        })?;

        *slot = Some(loaded);
        Ok(slot)
    }
}
67
#[cfg(feature = "embeddings")]
impl Default for LocalEmbedder {
    /// Equivalent to [`LocalEmbedder::new`].
    fn default() -> Self {
        LocalEmbedder::new()
    }
}
74
#[cfg(feature = "embeddings")]
#[async_trait::async_trait]
impl Embedder for LocalEmbedder {
    /// Embeds `texts` with the locally held model.
    ///
    /// An empty `texts` slice yields an empty result immediately, without
    /// locking the mutex or initialising the model.
    ///
    /// The model mutex is held for the duration of the embedding call; there
    /// is no `.await` while the guard is alive.
    ///
    /// # Errors
    ///
    /// Returns `StarpodError::Agent` when the model cannot be locked or
    /// initialised, when the model slot is unexpectedly empty, or when the
    /// underlying `embed` call fails.
    async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        // Nothing to embed: avoid paying for model initialisation on a no-op.
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        let guard = self.get_or_init()?;
        // `get_or_init` leaves the slot populated on success; surface a
        // violated invariant as an error instead of panicking.
        let model = guard.as_ref().ok_or_else(|| {
            StarpodError::Agent("Embedding model missing after initialization".to_string())
        })?;

        let vectors = model
            .embed(texts.to_vec(), None)
            .map_err(|e| StarpodError::Agent(format!("Embedding failed: {}", e)))?;
        Ok(vectors)
    }

    /// Returns the hard-coded value 384.
    ///
    /// NOTE(review): presumably the output width of the model selected in
    /// `get_or_init`; keep the two in sync if the model ever changes.
    fn dimensions(&self) -> usize {
        384
    }
}
91
/// Computes the cosine similarity between `a` and `b`.
///
/// The result lies in `[-1.0, 1.0]`: `1.0` for vectors pointing the same way,
/// `0.0` for orthogonal vectors, `-1.0` for opposite vectors.
///
/// * If either vector has zero magnitude (including an empty slice), `0.0`
///   is returned instead of dividing by zero.
/// * If the slices differ in length, the dot product only covers the common
///   prefix while each norm covers the full slice, which is equivalent to
///   padding the shorter vector with zeros.
/// * Non-finite inputs (`NaN`, `±inf`) may produce `NaN`.
///
/// Accumulation is done in `f64` so that squaring very large or very small
/// `f32` components neither overflows to infinity (which would yield `NaN`)
/// nor underflows to zero (which would wrongly report `0.0`).
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // Euclidean norm, accumulated in f64 to keep the squares representable.
    let norm = |v: &[f32]| -> f64 {
        v.iter()
            .map(|&x| f64::from(x) * f64::from(x))
            .sum::<f64>()
            .sqrt()
    };

    let dot: f64 = a
        .iter()
        .zip(b.iter())
        .map(|(&x, &y)| f64::from(x) * f64::from(y))
        .sum();
    let norm_a = norm(a);
    let norm_b = norm(b);

    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    // Rounding can push the ratio a hair outside the mathematical range.
    let cosine = (dot / (norm_a * norm_b)).clamp(-1.0, 1.0);
    // Narrowing is safe: the value is within [-1, 1] (or NaN, which is preserved).
    cosine as f32
}
113
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `actual` is strictly within `tol` of `expected`.
    fn close_to(actual: f32, expected: f32, tol: f32) -> bool {
        (actual - expected).abs() < tol
    }

    /// A vector compared with itself has similarity 1.
    #[test]
    fn cosine_identical_vectors() {
        let v = [1.0_f32, 2.0, 3.0];
        assert!(close_to(cosine_similarity(&v, &v), 1.0, 1e-6));
    }

    /// Perpendicular unit vectors have similarity 0.
    #[test]
    fn cosine_orthogonal_vectors() {
        let x_axis = [1.0_f32, 0.0];
        let y_axis = [0.0_f32, 1.0];
        assert!(close_to(cosine_similarity(&x_axis, &y_axis), 0.0, 1e-6));
    }

    /// Vectors pointing in opposite directions have similarity -1.
    #[test]
    fn cosine_opposite_vectors() {
        let forward = [1.0_f32, 0.0];
        let backward = [-1.0_f32, 0.0];
        assert!(close_to(cosine_similarity(&forward, &backward), -1.0, 1e-6));
    }

    /// A zero vector on one side short-circuits to exactly 0.
    #[test]
    fn cosine_zero_vector() {
        let nonzero = [1.0_f32, 2.0];
        let zero = [0.0_f32, 0.0];
        assert_eq!(cosine_similarity(&nonzero, &zero), 0.0);
    }

    /// Two zero vectors also yield exactly 0 rather than NaN.
    #[test]
    fn cosine_both_zero_vectors() {
        let zero = [0.0_f32, 0.0];
        let also_zero = [0.0_f32, 0.0];
        assert_eq!(cosine_similarity(&zero, &also_zero), 0.0);
    }

    /// Same identity check at 384 components, with a looser tolerance.
    #[test]
    fn cosine_high_dimensional() {
        let a: Vec<f32> = (0..384).map(|i| (i as f32).sin()).collect();
        let b = a.to_vec();
        let sim = cosine_similarity(&a, &b);
        assert!(
            close_to(sim, 1.0, 1e-5),
            "Identical 384-dim vectors should have sim ~1.0, got {}",
            sim
        );
    }

    /// Mismatched lengths: the dot product only pairs up the common prefix.
    /// The extra components of `longer` are zero, so its norm is unaffected.
    #[test]
    fn cosine_different_lengths_uses_shorter() {
        let longer = [1.0_f32, 0.0, 0.0];
        let shorter = [1.0_f32];
        assert!(close_to(cosine_similarity(&longer, &shorter), 1.0, 1e-6));
    }

    /// Scaling a vector by a positive factor does not change the similarity.
    #[test]
    fn cosine_scaled_vectors_are_equal() {
        let base = [1.0_f32, 2.0, 3.0];
        let doubled = [2.0_f32, 4.0, 6.0];
        let sim = cosine_similarity(&base, &doubled);
        assert!(
            close_to(sim, 1.0, 1e-6),
            "Scaled vectors should have similarity 1.0"
        );
    }
}