Skip to main content

starpod_memory/
embedder.rs

//! Embedding support for vector search.
//!
//! Provides the [`Embedder`] trait for pluggable text embedding models and a
//! concrete [`LocalEmbedder`] (behind the `embeddings` feature) that uses
//! [fastembed](https://docs.rs/fastembed) with the BGE-Small-EN v1.5 model
//! (384 dimensions, ~45 MB on disk).
//!
//! Also provides [`cosine_similarity`] for comparing embedding vectors.

10use starpod_core::Result;
11#[cfg(feature = "embeddings")]
12use starpod_core::StarpodError;
13
14/// Trait for text embedding models.
15///
16/// Implementations must be `Send + Sync` to allow sharing across async tasks
17/// via `Arc<dyn Embedder>`.
18#[async_trait::async_trait]
19pub trait Embedder: Send + Sync {
20    /// Embed one or more texts into fixed-dimensional vectors.
21    ///
22    /// Returns one vector per input text. All vectors have the same
23    /// dimensionality (see [`dimensions`](Self::dimensions)).
24    async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>>;
25
26    /// Dimensionality of the output vectors (e.g., 384 for BGE-Small-EN).
27    fn dimensions(&self) -> usize;
28}
29
/// Local embedder using fastembed (BGE-Small-EN v1.5, 384 dims).
///
/// The model is lazily initialized on the first call to [`embed`](Embedder::embed),
/// which downloads the model weights (~45 MB) if not already cached.
///
/// Thread-safety: the lazily created model is stored in a `std::sync::Mutex`, so
/// concurrent calls are serialized on that lock. `Send + Sync` are auto-derived from
/// this field (assuming `fastembed::TextEmbedding: Send`); the `Embedder` trait only
/// *requires* them as supertraits, it does not provide them.
#[cfg(feature = "embeddings")]
pub struct LocalEmbedder {
    /// `None` until the first successful model initialization, `Some(model)` afterwards.
    model: std::sync::Mutex<Option<fastembed::TextEmbedding>>,
}
41
#[cfg(feature = "embeddings")]
impl LocalEmbedder {
    /// Build an embedder with an empty model slot; nothing is loaded until first use.
    pub fn new() -> Self {
        let model = std::sync::Mutex::new(None);
        Self { model }
    }

    /// Lock the model slot and, if it is still empty, load BGE-Small-EN v1.5 into it.
    ///
    /// On success the returned guard's slot holds `Some(model)`. The lock stays held
    /// while the model is being created, so only one caller performs initialization.
    ///
    /// # Errors
    /// Returns `StarpodError::Agent` if the mutex is poisoned or if fastembed fails
    /// to create the model.
    fn get_or_init(&self) -> Result<std::sync::MutexGuard<'_, Option<fastembed::TextEmbedding>>> {
        let mut slot = self.model.lock().map_err(|poisoned| {
            StarpodError::Agent(format!("Embedder lock poisoned: {}", poisoned))
        })?;

        // Fast path: the model was already created by an earlier call.
        if slot.is_some() {
            return Ok(slot);
        }

        let options = fastembed::InitOptions::new(fastembed::EmbeddingModel::BGESmallENV15)
            .with_show_download_progress(false);
        let loaded = fastembed::TextEmbedding::try_new(options).map_err(|err| {
            StarpodError::Agent(format!("Failed to init embedding model: {}", err))
        })?;
        *slot = Some(loaded);
        Ok(slot)
    }
}
68
#[cfg(feature = "embeddings")]
impl Default for LocalEmbedder {
    /// Equivalent to [`LocalEmbedder::new`]: no model is loaded until first use.
    fn default() -> Self {
        LocalEmbedder::new()
    }
}
75
#[cfg(feature = "embeddings")]
#[async_trait::async_trait]
impl Embedder for LocalEmbedder {
    /// Embed `texts` with the local BGE-Small-EN v1.5 model.
    ///
    /// An empty `texts` slice returns an empty result immediately, without taking
    /// the lock or initializing the model. No weights are downloaded just to embed
    /// nothing.
    ///
    /// NOTE(review): this body contains no `.await`. The first-call model
    /// initialization (possibly a download) and every inference therefore run
    /// synchronously on whichever executor thread polls this future. The model lock
    /// is also held for the whole inference, serializing concurrent calls. Callers
    /// on a shared async runtime may want to move calls onto a blocking-task pool.
    ///
    /// # Errors
    /// Returns `StarpodError::Agent` if the model lock is poisoned, the model fails
    /// to initialize, or fastembed reports an embedding failure.
    async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        let guard = self.get_or_init()?;
        let model = guard
            .as_ref()
            .expect("get_or_init only returns Ok with Some(model) in the slot");
        let results = model
            .embed(texts.to_vec(), None)
            .map_err(|e| StarpodError::Agent(format!("Embedding failed: {}", e)))?;
        Ok(results)
    }

    /// Output dimensionality of BGE-Small-EN v1.5.
    fn dimensions(&self) -> usize {
        384
    }
}
92
/// Compute cosine similarity between two vectors.
///
/// Returns a value in `[-1.0, 1.0]`:
/// - `1.0` = identical direction
/// - `0.0` = orthogonal (unrelated)
/// - `-1.0` = opposite direction
///
/// The raw quotient is clamped to `[-1.0, 1.0]`. Otherwise `f32` rounding can push it
/// marginally outside that range (e.g. just above `1.0` for identical vectors), which
/// breaks the documented contract for callers that threshold or rank on the value.
///
/// If either vector has zero norm (all components zero, or an empty slice),
/// returns `0.0`.
///
/// On a length mismatch, the dot product uses only the overlapping prefix
/// (`min(a.len(), b.len())` pairs via `zip`). Each norm is still taken over the
/// *whole* vector, so trailing components of the longer vector lower the similarity.
/// In practice both vectors should have the same dimensionality.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    (dot / (norm_a * norm_b)).clamp(-1.0, 1.0)
}
114
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cosine_identical_vectors() {
        let v = vec![1.0, 2.0, 3.0];
        let sim = cosine_similarity(&v, &v);
        assert!((sim - 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_orthogonal_vectors() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        let sim = cosine_similarity(&a, &b);
        assert!(sim.abs() < 1e-6);
    }

    #[test]
    fn cosine_opposite_vectors() {
        let a = vec![1.0, 0.0];
        let b = vec![-1.0, 0.0];
        let sim = cosine_similarity(&a, &b);
        assert!((sim - (-1.0)).abs() < 1e-6);
    }

    #[test]
    fn cosine_zero_vector() {
        let a = vec![1.0, 2.0];
        let b = vec![0.0, 0.0];
        assert_eq!(cosine_similarity(&a, &b), 0.0);
    }

    #[test]
    fn cosine_both_zero_vectors() {
        let a = vec![0.0, 0.0];
        let b = vec![0.0, 0.0];
        assert_eq!(cosine_similarity(&a, &b), 0.0);
    }

    #[test]
    fn cosine_empty_slices() {
        // An empty slice has zero norm, so the zero-norm guard returns 0.0.
        let empty: [f32; 0] = [];
        assert_eq!(cosine_similarity(&empty, &empty), 0.0);
        assert_eq!(cosine_similarity(&[1.0, 2.0], &empty), 0.0);
    }

    #[test]
    fn cosine_high_dimensional() {
        // 384-dim vectors (same as BGE-Small-EN) — identical direction
        let a: Vec<f32> = (0..384).map(|i| (i as f32).sin()).collect();
        let b = a.clone();
        let sim = cosine_similarity(&a, &b);
        assert!(
            (sim - 1.0).abs() < 1e-5,
            "Identical 384-dim vectors should have sim ~1.0, got {}",
            sim
        );
    }

    #[test]
    fn cosine_different_lengths_uses_shorter() {
        // zip truncates the dot product to the shorter length: [1,0,0] . [1] = 1.
        // Norms still use every component of each vector.
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0];
        let sim = cosine_similarity(&a, &b);
        // dot = 1, norm_a = sqrt(1+0+0) = 1, norm_b = 1
        assert!((sim - 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_different_lengths_extra_components_lower_similarity() {
        // Trailing components of the longer vector are excluded from the dot product
        // but still count towards its norm: dot = 1, norm_a = sqrt(2), norm_b = 1.
        let a = vec![1.0, 1.0];
        let b = vec![1.0];
        let sim = cosine_similarity(&a, &b);
        assert!((sim - std::f32::consts::FRAC_1_SQRT_2).abs() < 1e-6);
    }

    #[test]
    fn cosine_is_symmetric() {
        let a = vec![0.3, -1.2, 4.5, 0.0];
        let b = vec![2.0, 0.7, -0.1, 3.3];
        assert_eq!(cosine_similarity(&a, &b), cosine_similarity(&b, &a));
    }

    #[test]
    fn cosine_scaled_vectors_are_equal() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![2.0, 4.0, 6.0]; // same direction, 2x magnitude
        let sim = cosine_similarity(&a, &b);
        assert!(
            (sim - 1.0).abs() < 1e-6,
            "Scaled vectors should have similarity 1.0"
        );
    }
}