Skip to main content

semantic_memory/
embedder.rs

1//! Embedding trait and implementations.
2//!
3//! Provides the [`Embedder`] trait for text-to-vector conversion,
4//! with [`OllamaEmbedder`] (production) and [`MockEmbedder`] (testing).
5
6use crate::config::EmbeddingConfig;
7use crate::error::MemoryError;
8use std::future::Future;
9use std::hash::{Hash, Hasher};
10use std::pin::Pin;
11
12/// Boxed future type alias for single embedding results.
13pub type EmbedFuture<'a> = Pin<Box<dyn Future<Output = Result<Vec<f32>, MemoryError>> + Send + 'a>>;
14
15/// Boxed future type alias for batch embedding results.
16pub type EmbedBatchFuture<'a> =
17    Pin<Box<dyn Future<Output = Result<Vec<Vec<f32>>, MemoryError>> + Send + 'a>>;
18
19/// Trait for embedding text into vectors.
20///
21/// Implement this to swap embedding providers.
22pub trait Embedder: Send + Sync {
23    /// Embed a single text. Returns a vector of f32.
24    fn embed<'a>(&'a self, text: &'a str) -> EmbedFuture<'a>;
25
26    /// Embed multiple texts in a batch.
27    ///
28    /// Takes owned strings to avoid lifetime issues across async boundaries.
29    fn embed_batch<'a>(&'a self, texts: Vec<String>) -> EmbedBatchFuture<'a>;
30
31    /// The model name this embedder uses.
32    fn model_name(&self) -> &str;
33
34    /// Expected embedding dimensions.
35    fn dimensions(&self) -> usize;
36}
37
38// ─── OllamaEmbedder ─────────────────────────────────────────────
39
40/// Embedding provider that calls Ollama's `/api/embed` endpoint.
41pub struct OllamaEmbedder {
42    client: reqwest::Client,
43    base_url: String,
44    model: String,
45    dimensions: usize,
46    batch_size: usize,
47}
48
49impl OllamaEmbedder {
50    /// Create a new OllamaEmbedder from config.
51    pub fn new(config: &EmbeddingConfig) -> Self {
52        let client = reqwest::Client::builder()
53            .timeout(std::time::Duration::from_secs(config.timeout_secs))
54            .build()
55            .expect("Failed to build reqwest client");
56
57        Self {
58            client,
59            base_url: config.ollama_url.trim_end_matches('/').to_string(),
60            model: config.model.clone(),
61            dimensions: config.dimensions,
62            batch_size: config.batch_size,
63        }
64    }
65}
66
67impl Embedder for OllamaEmbedder {
68    fn embed<'a>(&'a self, text: &'a str) -> EmbedFuture<'a> {
69        Box::pin(async move {
70            let mut results = self.embed_batch(vec![text.to_string()]).await?;
71            results.pop().ok_or_else(|| {
72                MemoryError::Other("Ollama returned empty embeddings for single text".to_string())
73            })
74        })
75    }
76
77    fn embed_batch<'a>(&'a self, texts: Vec<String>) -> EmbedBatchFuture<'a> {
78        Box::pin(async move {
79            let mut all_embeddings = Vec::with_capacity(texts.len());
80
81            for batch in texts.chunks(self.batch_size) {
82                let input: Vec<&str> = batch.iter().map(|s| s.as_str()).collect();
83                let body = serde_json::json!({
84                    "model": self.model,
85                    "input": input
86                });
87
88                let url = format!("{}/api/embed", self.base_url);
89                let response = self
90                    .client
91                    .post(&url)
92                    .json(&body)
93                    .send()
94                    .await
95                    .map_err(|e| {
96                        if e.is_connect() {
97                            MemoryError::EmbedderUnavailable(format!(
98                                "Ollama not running at {}",
99                                self.base_url
100                            ))
101                        } else if e.is_timeout() {
102                            MemoryError::EmbedderUnavailable(format!(
103                                "Ollama embedding timed out: {}",
104                                e
105                            ))
106                        } else {
107                            MemoryError::EmbeddingRequest(e)
108                        }
109                    })?;
110
111                if response.status() == reqwest::StatusCode::NOT_FOUND {
112                    return Err(MemoryError::EmbedderUnavailable(format!(
113                        "Model '{}' not available in Ollama. Run: ollama pull {}",
114                        self.model, self.model
115                    )));
116                }
117
118                if !response.status().is_success() {
119                    let status = response.status();
120                    let body = response.text().await.unwrap_or_default();
121                    return Err(MemoryError::Other(format!(
122                        "Ollama returned HTTP {}: {}",
123                        status,
124                        &body[..body.len().min(500)]
125                    )));
126                }
127
128                let resp_body: serde_json::Value = response.json().await?;
129                let batch_embeddings = parse_embedding_response(&resp_body, self.dimensions)?;
130                all_embeddings.extend(batch_embeddings);
131            }
132
133            Ok(all_embeddings)
134        })
135    }
136
137    fn model_name(&self) -> &str {
138        &self.model
139    }
140
141    fn dimensions(&self) -> usize {
142        self.dimensions
143    }
144}
145
146/// Parse an Ollama embedding response body into vectors.
147///
148/// Validates that all values are numeric and dimensions match.
149#[doc(hidden)]
150pub fn parse_embedding_response(
151    body: &serde_json::Value,
152    expected_dims: usize,
153) -> Result<Vec<Vec<f32>>, MemoryError> {
154    let embeddings = body["embeddings"].as_array().ok_or_else(|| {
155        MemoryError::Other("Ollama response missing 'embeddings' field".to_string())
156    })?;
157
158    let mut result = Vec::with_capacity(embeddings.len());
159    for embedding_val in embeddings {
160        let raw_array = embedding_val
161            .as_array()
162            .ok_or_else(|| MemoryError::Other("Embedding is not an array".to_string()))?;
163
164        let mut embedding = Vec::with_capacity(raw_array.len());
165        for (i, v) in raw_array.iter().enumerate() {
166            let val = v.as_f64().ok_or_else(|| {
167                MemoryError::Other(format!(
168                    "Embedding dimension {} contains non-numeric value: {}",
169                    i, v
170                ))
171            })?;
172            embedding.push(val as f32);
173        }
174
175        if embedding.len() != expected_dims {
176            return Err(MemoryError::DimensionMismatch {
177                expected: expected_dims,
178                actual: embedding.len(),
179            });
180        }
181
182        result.push(embedding);
183    }
184
185    Ok(result)
186}
187
188// ─── MockEmbedder ────────────────────────────────────────────────
189
/// Deterministic, network-free embedder for unit tests.
///
/// Each text is hashed to seed a small pseudo-random generator. The same text therefore
/// always maps to the same vector within one build, and every non-empty vector is
/// normalized to unit length.
#[derive(Debug, Clone)]
pub struct MockEmbedder {
    /// Length of every vector this embedder produces.
    dimensions: usize,
}

impl MockEmbedder {
    /// Create a `MockEmbedder` that produces vectors of length `dimensions`.
    #[must_use]
    pub fn new(dimensions: usize) -> Self {
        Self { dimensions }
    }
}
204
205impl Embedder for MockEmbedder {
206    fn embed<'a>(&'a self, text: &'a str) -> EmbedFuture<'a> {
207        let embedding = deterministic_embedding(text, self.dimensions);
208        Box::pin(async move { Ok(embedding) })
209    }
210
211    fn embed_batch<'a>(&'a self, texts: Vec<String>) -> EmbedBatchFuture<'a> {
212        let embeddings: Vec<Vec<f32>> = texts
213            .iter()
214            .map(|t| deterministic_embedding(t, self.dimensions))
215            .collect();
216        Box::pin(async move { Ok(embeddings) })
217    }
218
219    fn model_name(&self) -> &str {
220        "mock-embedder"
221    }
222
223    fn dimensions(&self) -> usize {
224        self.dimensions
225    }
226}
227
/// Produce a reproducible pseudo-random vector of length `dimensions` for `text`.
///
/// The `DefaultHasher` digest of `text` seeds a xorshift64 generator. Each step yields one
/// component mapped from `[0, u64::MAX]` onto `[-1.0, 1.0]`. The vector is then scaled to
/// unit length; scaling is skipped when the magnitude is zero (e.g. `dimensions == 0`).
///
/// NOTE(review): `DefaultHasher`'s algorithm is not guaranteed stable across Rust releases.
/// Outputs are reproducible within one build, but should not be hard-coded or persisted.
fn deterministic_embedding(text: &str, dimensions: usize) -> Vec<f32> {
    let seed = {
        let mut digest = std::hash::DefaultHasher::new();
        text.hash(&mut digest);
        digest.finish()
    };
    // xorshift64 maps 0 to 0 forever, so never start from a zero state.
    let mut x = if seed == 0 { 1 } else { seed };

    let mut components: Vec<f32> = (0..dimensions)
        .map(|_| {
            x ^= x << 13;
            x ^= x >> 7;
            x ^= x << 17;
            let unit = (x as f64) / (u64::MAX as f64);
            (unit * 2.0 - 1.0) as f32
        })
        .collect();

    let norm = components.iter().map(|c| c * c).sum::<f32>().sqrt();
    if norm > 0.0 {
        components.iter_mut().for_each(|c| *c /= norm);
    }
    components
}