starpod_memory/
embedder.rs1use starpod_core::Result;
11#[cfg(feature = "embeddings")]
12use starpod_core::StarpodError;
13
14#[async_trait::async_trait]
19pub trait Embedder: Send + Sync {
20 async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>>;
25
26 fn dimensions(&self) -> usize;
28}
29
/// Embedder backed by a `fastembed::TextEmbedding` model.
///
/// Only compiled when the `embeddings` feature is enabled.
#[cfg(feature = "embeddings")]
pub struct LocalEmbedder {
    /// Slot for the model, guarded by a mutex; `None` until a model has been
    /// stored in it.
    model: std::sync::Mutex<Option<fastembed::TextEmbedding>>,
}
41
#[cfg(feature = "embeddings")]
impl LocalEmbedder {
    /// Creates an embedder with an empty model slot.
    ///
    /// No model is constructed here; that is deferred to the first call of
    /// `get_or_init`.
    pub fn new() -> Self {
        let model = std::sync::Mutex::new(None);
        Self { model }
    }

    /// Locks the model slot, filling it first if it is still empty.
    ///
    /// The returned guard always wraps `Some(model)` on success.
    ///
    /// # Errors
    ///
    /// Returns `StarpodError::Agent` when the mutex is poisoned or when
    /// `fastembed` fails to construct the model.
    fn get_or_init(&self) -> Result<std::sync::MutexGuard<'_, Option<fastembed::TextEmbedding>>> {
        let mut slot = self
            .model
            .lock()
            .map_err(|e| StarpodError::Agent(format!("Embedder lock poisoned: {}", e)))?;

        // Already initialised by an earlier call: hand the guard straight back.
        if slot.is_some() {
            return Ok(slot);
        }

        let options = fastembed::InitOptions::new(fastembed::EmbeddingModel::BGESmallENV15)
            .with_show_download_progress(false);
        let loaded = fastembed::TextEmbedding::try_new(options)
            .map_err(|e| StarpodError::Agent(format!("Failed to init embedding model: {}", e)))?;
        *slot = Some(loaded);
        Ok(slot)
    }
}
68
/// `Default` simply delegates to [`LocalEmbedder::new`].
#[cfg(feature = "embeddings")]
impl Default for LocalEmbedder {
    fn default() -> Self {
        LocalEmbedder::new()
    }
}
75
#[cfg(feature = "embeddings")]
#[async_trait::async_trait]
impl Embedder for LocalEmbedder {
    /// Embeds `texts` with the local `fastembed` model.
    ///
    /// An empty `texts` slice yields an empty result without locking or
    /// initialising the model.
    ///
    /// # Errors
    ///
    /// Returns `StarpodError::Agent` when the model lock is poisoned, the
    /// model cannot be initialised, or `fastembed` reports an embedding
    /// failure.
    async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        // Nothing to embed: skip the lock and the potentially expensive
        // first-time model initialisation.
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        let guard = self.get_or_init()?;
        // `get_or_init` fills the slot before returning; surface a broken
        // invariant as an error instead of panicking inside library code.
        let model = guard
            .as_ref()
            .ok_or_else(|| StarpodError::Agent("Embedding model not initialized".to_string()))?;
        let results = model
            .embed(texts.to_vec(), None)
            .map_err(|e| StarpodError::Agent(format!("Embedding failed: {}", e)))?;
        Ok(results)
    }

    /// Reports 384, the hard-coded vector width for this embedder.
    fn dimensions(&self) -> usize {
        384
    }
}
92
/// Computes the cosine similarity between two vectors.
///
/// When the slices differ in length, only the common prefix (the first
/// `min(a.len(), b.len())` components of each) is compared, so the dot
/// product and both norms are computed over the same components.
///
/// Returns `0.0` when either compared vector has zero magnitude, which
/// includes empty input. The result is clamped to `[-1.0, 1.0]`.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // Truncate both inputs to the shared length so numerator and denominator
    // agree on which components take part.
    let len = a.len().min(b.len());
    let (a, b) = (&a[..len], &b[..len]);

    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();

    // A zero-magnitude vector has no direction; avoid dividing by zero.
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    // Rounding can push the quotient marginally outside the mathematical
    // range, which would upset callers that feed it to e.g. `acos`.
    (dot / (norm_a * norm_b)).clamp(-1.0, 1.0)
}
114
#[cfg(test)]
mod tests {
    use super::*;

    /// A vector compared with itself has similarity 1.
    #[test]
    fn cosine_identical_vectors() {
        let v = [1.0_f32, 2.0, 3.0];
        let delta = cosine_similarity(&v, &v) - 1.0;
        assert!(delta.abs() < 1e-6);
    }

    /// Perpendicular unit vectors have similarity 0.
    #[test]
    fn cosine_orthogonal_vectors() {
        let x_axis = [1.0_f32, 0.0];
        let y_axis = [0.0_f32, 1.0];
        let sim = cosine_similarity(&x_axis, &y_axis);
        assert!(sim.abs() < 1e-6);
    }

    /// Vectors pointing in opposite directions have similarity -1.
    #[test]
    fn cosine_opposite_vectors() {
        let forward = [1.0_f32, 0.0];
        let backward = [-1.0_f32, 0.0];
        let sim = cosine_similarity(&forward, &backward);
        assert!((sim - (-1.0)).abs() < 1e-6);
    }

    /// One zero-magnitude operand yields exactly 0 instead of dividing by zero.
    #[test]
    fn cosine_zero_vector() {
        let nonzero = [1.0_f32, 2.0];
        let zero = [0.0_f32, 0.0];
        assert_eq!(cosine_similarity(&nonzero, &zero), 0.0);
    }

    /// Two zero-magnitude operands also yield exactly 0.
    #[test]
    fn cosine_both_zero_vectors() {
        let zero = [0.0_f32, 0.0];
        let also_zero = [0.0_f32, 0.0];
        assert_eq!(cosine_similarity(&zero, &also_zero), 0.0);
    }

    /// Self-similarity stays close to 1 for a 384-component vector.
    #[test]
    fn cosine_high_dimensional() {
        let a: Vec<f32> = (0..384).map(|i| (i as f32).sin()).collect();
        let b = a.clone();
        let sim = cosine_similarity(&a, &b);
        assert!(
            (sim - 1.0).abs() < 1e-5,
            "Identical 384-dim vectors should have sim ~1.0, got {}",
            sim
        );
    }

    /// The extra components of the longer vector are zeros here, so only the
    /// shared first component contributes to the result.
    #[test]
    fn cosine_different_lengths_uses_shorter() {
        let longer = [1.0_f32, 0.0, 0.0];
        let shorter = [1.0_f32];
        let sim = cosine_similarity(&longer, &shorter);
        assert!((sim - 1.0).abs() < 1e-6);
    }

    /// Scaling a vector by a positive factor does not change its direction.
    #[test]
    fn cosine_scaled_vectors_are_equal() {
        let base = [1.0_f32, 2.0, 3.0];
        let doubled = [2.0_f32, 4.0, 6.0];
        let sim = cosine_similarity(&base, &doubled);
        assert!(
            (sim - 1.0).abs() < 1e-6,
            "Scaled vectors should have similarity 1.0"
        );
    }
}