Skip to main content

starpod_memory/
embedder.rs

1//! Embedding support for vector search.
2//!
3//! Provides the [`Embedder`] trait for pluggable text embedding models and a
4//! concrete [`LocalEmbedder`] (behind the `embeddings` feature) that uses
5//! [fastembed](https://docs.rs/fastembed) with the BGE-Small-EN v1.5 model
6//! (384 dimensions, ~45 MB on disk).
7//!
8//! Also provides [`cosine_similarity`] for comparing embedding vectors.
9
10use starpod_core::Result;
11#[cfg(feature = "embeddings")]
12use starpod_core::StarpodError;
13
14/// Trait for text embedding models.
15///
16/// Implementations must be `Send + Sync` to allow sharing across async tasks
17/// via `Arc<dyn Embedder>`.
18#[async_trait::async_trait]
19pub trait Embedder: Send + Sync {
20    /// Embed one or more texts into fixed-dimensional vectors.
21    ///
22    /// Returns one vector per input text. All vectors have the same
23    /// dimensionality (see [`dimensions`](Self::dimensions)).
24    async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>>;
25
26    /// Dimensionality of the output vectors (e.g., 384 for BGE-Small-EN).
27    fn dimensions(&self) -> usize;
28}
29
/// [`Embedder`] backed by a local fastembed model (BGE-Small-EN v1.5, 384 dims).
///
/// Construction is cheap: the model slot starts out empty and is filled on
/// the first call to [`embed`](Embedder::embed), which downloads the model
/// weights (~45 MB) if they are not already cached.
///
/// The model lives behind a `std::sync::Mutex`, so a `LocalEmbedder` shared
/// between threads hands the model to one caller at a time.
#[cfg(feature = "embeddings")]
pub struct LocalEmbedder {
    // `None` until the first successful initialization; `Some` afterwards.
    model: std::sync::Mutex<Option<fastembed::TextEmbedding>>,
}
41
#[cfg(feature = "embeddings")]
impl LocalEmbedder {
    /// Build a `LocalEmbedder` with an empty model slot.
    ///
    /// No model is loaded here; that happens on first use.
    pub fn new() -> Self {
        let model = std::sync::Mutex::new(None);
        Self { model }
    }

    /// Lock the model slot, loading the fastembed model first if the slot is
    /// still empty.
    ///
    /// On success the returned guard always holds `Some(model)`.
    ///
    /// # Errors
    ///
    /// Returns `StarpodError::Agent` when the mutex is poisoned or when
    /// fastembed fails to initialize the model.
    fn get_or_init(&self) -> Result<std::sync::MutexGuard<'_, Option<fastembed::TextEmbedding>>> {
        let mut slot = match self.model.lock() {
            Ok(slot) => slot,
            Err(poisoned) => {
                return Err(StarpodError::Agent(format!(
                    "Embedder lock poisoned: {}",
                    poisoned
                )));
            }
        };

        // Fast path: an earlier call already loaded the model.
        if slot.is_some() {
            return Ok(slot);
        }

        let options = fastembed::InitOptions::new(fastembed::EmbeddingModel::BGESmallENV15)
            .with_show_download_progress(false);
        match fastembed::TextEmbedding::try_new(options) {
            Ok(loaded) => {
                *slot = Some(loaded);
                Ok(slot)
            }
            Err(err) => Err(StarpodError::Agent(format!(
                "Failed to init embedding model: {}",
                err
            ))),
        }
    }
}
67
#[cfg(feature = "embeddings")]
impl Default for LocalEmbedder {
    /// Same as [`LocalEmbedder::new`]: an embedder whose model has not been
    /// loaded yet.
    fn default() -> Self {
        LocalEmbedder::new()
    }
}
74
#[cfg(feature = "embeddings")]
#[async_trait::async_trait]
impl Embedder for LocalEmbedder {
    /// Embed `texts` with the local fastembed model, loading it on first use.
    ///
    /// Returns one vector per input text. An empty `texts` slice yields an
    /// empty result without touching the model, so it never triggers the
    /// lazy model load.
    ///
    /// # Errors
    ///
    /// Returns `StarpodError::Agent` if the model lock is poisoned, the model
    /// cannot be initialized, or fastembed reports an embedding failure.
    async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        // Nothing to embed: skip the (potentially slow) lazy initialization.
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        // NOTE(review): the guard is held for the whole synchronous embedding
        // call, so concurrent callers are serialized and the calling task is
        // blocked meanwhile — confirm this is acceptable for the runtime used.
        let guard = self.get_or_init()?;

        // `get_or_init` only returns a guard holding `Some`; report an error
        // rather than panicking should that invariant ever be broken.
        let model = guard.as_ref().ok_or_else(|| {
            StarpodError::Agent("Embedding model not initialized".to_string())
        })?;

        model
            .embed(texts.to_vec(), None)
            .map_err(|e| StarpodError::Agent(format!("Embedding failed: {}", e)))
    }

    /// Output dimensionality of BGE-Small-EN v1.5, the model loaded by
    /// `get_or_init`.
    fn dimensions(&self) -> usize {
        384
    }
}
91
/// Compute cosine similarity between two vectors.
///
/// Returns a value in `[-1.0, 1.0]`:
/// - `1.0` = identical direction
/// - `0.0` = orthogonal (unrelated)
/// - `-1.0` = opposite direction
///
/// If either vector has zero magnitude (every component is `0.0`, or the
/// slice is empty), returns `0.0`.
///
/// When the slices differ in length, the dot product pairs up only the first
/// `min(a.len(), b.len())` components (via `zip`), while each norm is taken
/// over its full slice. This is equivalent to zero-padding the shorter
/// vector. In practice both vectors should have the same dimensionality.
///
/// Sums are accumulated in `f64` so that components whose squares would
/// overflow or underflow `f32` still produce a meaningful result, and the
/// quotient is clamped so rounding error cannot push it outside the
/// documented range. Inputs are expected to be finite.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    /// Euclidean length of `v`, accumulated in `f64`.
    fn l2_norm(v: &[f32]) -> f64 {
        v.iter()
            .map(|&x| f64::from(x) * f64::from(x))
            .sum::<f64>()
            .sqrt()
    }

    let norm_a = l2_norm(a);
    let norm_b = l2_norm(b);
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    let dot: f64 = a
        .iter()
        .zip(b)
        .map(|(&x, &y)| f64::from(x) * f64::from(y))
        .sum();

    // The clamped value lies in [-1, 1], so narrowing to `f32` only rounds.
    (dot / (norm_a * norm_b)).clamp(-1.0, 1.0) as f32
}
113
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance shared by the small hand-written cases.
    const EPS: f32 = 1e-6;

    #[test]
    fn cosine_identical_vectors() {
        let v = [1.0_f32, 2.0, 3.0];
        let diff = cosine_similarity(&v, &v) - 1.0;
        assert!(diff.abs() < EPS);
    }

    #[test]
    fn cosine_orthogonal_vectors() {
        let x_axis = [1.0_f32, 0.0];
        let y_axis = [0.0_f32, 1.0];
        assert!(cosine_similarity(&x_axis, &y_axis).abs() < EPS);
    }

    #[test]
    fn cosine_opposite_vectors() {
        let forward = [1.0_f32, 0.0];
        let backward = [-1.0_f32, 0.0];
        let diff = cosine_similarity(&forward, &backward) - (-1.0);
        assert!(diff.abs() < EPS);
    }

    #[test]
    fn cosine_zero_vector() {
        let nonzero = [1.0_f32, 2.0];
        let zero = [0.0_f32, 0.0];
        assert_eq!(cosine_similarity(&nonzero, &zero), 0.0);
    }

    #[test]
    fn cosine_both_zero_vectors() {
        let zero = [0.0_f32, 0.0];
        let other_zero = [0.0_f32, 0.0];
        assert_eq!(cosine_similarity(&zero, &other_zero), 0.0);
    }

    #[test]
    fn cosine_high_dimensional() {
        // Same dimensionality as BGE-Small-EN (384); both inputs point the same way.
        let a: Vec<f32> = (0..384).map(|i| (i as f32).sin()).collect();
        let b = a.clone();
        let sim = cosine_similarity(&a, &b);
        assert!(
            (sim - 1.0).abs() < 1e-5,
            "Identical 384-dim vectors should have sim ~1.0, got {}",
            sim
        );
    }

    #[test]
    fn cosine_different_lengths_uses_shorter() {
        // `zip` pairs only the first component, so dot = 1. The norms still
        // cover each full slice: sqrt(1 + 0 + 0) = 1 for `long`, 1 for `short`.
        let long = [1.0_f32, 0.0, 0.0];
        let short = [1.0_f32];
        let diff = cosine_similarity(&long, &short) - 1.0;
        assert!(diff.abs() < EPS);
    }

    #[test]
    fn cosine_scaled_vectors_are_equal() {
        let base = [1.0_f32, 2.0, 3.0];
        // Same direction as `base`, twice the magnitude.
        let doubled = [2.0_f32, 4.0, 6.0];
        let sim = cosine_similarity(&base, &doubled);
        assert!(
            (sim - 1.0).abs() < EPS,
            "Scaled vectors should have similarity 1.0"
        );
    }
}