Skip to main content

semantic_memory/
embedder.rs

1//! Embedding trait and implementations.
2//!
3//! Provides the [`Embedder`] trait for text-to-vector conversion,
4//! with [`OllamaEmbedder`] (production) and [`MockEmbedder`] (testing).
5
6use crate::config::EmbeddingConfig;
7use crate::error::MemoryError;
8use std::future::Future;
9use std::hash::{Hash, Hasher};
10use std::pin::Pin;
11
12/// Boxed future type alias for single embedding results.
13pub type EmbedFuture<'a> = Pin<Box<dyn Future<Output = Result<Vec<f32>, MemoryError>> + Send + 'a>>;
14
15/// Boxed future type alias for batch embedding results.
16pub type EmbedBatchFuture<'a> =
17    Pin<Box<dyn Future<Output = Result<Vec<Vec<f32>>, MemoryError>> + Send + 'a>>;
18
19/// Trait for embedding text into vectors.
20///
21/// Implement this to swap embedding providers.
22pub trait Embedder: Send + Sync {
23    /// Embed a single text. Returns a vector of f32.
24    fn embed<'a>(&'a self, text: &'a str) -> EmbedFuture<'a>;
25
26    /// Embed multiple texts in a batch.
27    ///
28    /// Takes owned strings to avoid lifetime issues across async boundaries.
29    fn embed_batch<'a>(&'a self, texts: Vec<String>) -> EmbedBatchFuture<'a>;
30
31    /// The model name this embedder uses.
32    fn model_name(&self) -> &str;
33
34    /// Expected embedding dimensions.
35    fn dimensions(&self) -> usize;
36}
37
38// ─── OllamaEmbedder ─────────────────────────────────────────────
39
40/// Embedding provider that calls Ollama's `/api/embed` endpoint.
41pub struct OllamaEmbedder {
42    client: reqwest::Client,
43    base_url: String,
44    model: String,
45    dimensions: usize,
46    batch_size: usize,
47}
48
49impl OllamaEmbedder {
50    /// Create a new OllamaEmbedder from config.
51    ///
52    /// Returns an error if the HTTP client cannot be constructed (e.g. TLS backend
53    /// is unavailable).
54    pub fn try_new(config: &EmbeddingConfig) -> Result<Self, MemoryError> {
55        let client = reqwest::Client::builder()
56            .timeout(std::time::Duration::from_secs(config.timeout_secs))
57            .build()
58            .map_err(|e| {
59                MemoryError::EmbedderUnavailable(format!("failed to build HTTP client: {e}"))
60            })?;
61
62        Ok(Self {
63            client,
64            base_url: config.ollama_url.trim_end_matches('/').to_string(),
65            model: config.model.clone(),
66            dimensions: config.dimensions,
67            batch_size: config.batch_size,
68        })
69    }
70
71    // GOV-005: Deprecated `new()` method removed — all consumers should use `try_new()`.
72}
73
74impl Embedder for OllamaEmbedder {
75    fn embed<'a>(&'a self, text: &'a str) -> EmbedFuture<'a> {
76        Box::pin(async move {
77            let mut results = self.embed_batch(vec![text.to_string()]).await?;
78            results.pop().ok_or_else(|| {
79                MemoryError::Other("Ollama returned empty embeddings for single text".to_string())
80            })
81        })
82    }
83
84    fn embed_batch<'a>(&'a self, texts: Vec<String>) -> EmbedBatchFuture<'a> {
85        Box::pin(async move {
86            let mut all_embeddings = Vec::with_capacity(texts.len());
87
88            for batch in texts.chunks(self.batch_size) {
89                let input: Vec<&str> = batch.iter().map(|s| s.as_str()).collect();
90                let body = serde_json::json!({
91                    "model": self.model,
92                    "input": input
93                });
94
95                let url = format!("{}/api/embed", self.base_url);
96                let response = self
97                    .client
98                    .post(&url)
99                    .json(&body)
100                    .send()
101                    .await
102                    .map_err(|e| {
103                        if e.is_connect() {
104                            MemoryError::EmbedderUnavailable(format!(
105                                "Ollama not running at {}",
106                                self.base_url
107                            ))
108                        } else if e.is_timeout() {
109                            MemoryError::EmbedderUnavailable(format!(
110                                "Ollama embedding timed out: {}",
111                                e
112                            ))
113                        } else {
114                            MemoryError::EmbeddingRequest(e)
115                        }
116                    })?;
117
118                if response.status() == reqwest::StatusCode::NOT_FOUND {
119                    return Err(MemoryError::EmbedderUnavailable(format!(
120                        "Model '{}' not available in Ollama. Run: ollama pull {}",
121                        self.model, self.model
122                    )));
123                }
124
125                if !response.status().is_success() {
126                    let status = response.status();
127                    let body = response
128                        .text()
129                        .await
130                        .map_err(|err| format!("failed to read Ollama error body: {err}"));
131                    return Err(format_ollama_http_error(status, body));
132                }
133
134                let resp_body: serde_json::Value = response.json().await?;
135                let batch_embeddings = parse_embedding_response(&resp_body, self.dimensions)?;
136                all_embeddings.extend(batch_embeddings);
137            }
138
139            Ok(all_embeddings)
140        })
141    }
142
143    fn model_name(&self) -> &str {
144        &self.model
145    }
146
147    fn dimensions(&self) -> usize {
148        self.dimensions
149    }
150}
151
152#[doc(hidden)]
153pub fn format_ollama_http_error(
154    status: reqwest::StatusCode,
155    body: Result<String, String>,
156) -> MemoryError {
157    match body {
158        Ok(body) => MemoryError::Other(format!(
159            "Ollama returned HTTP {}: {}",
160            status,
161            &body[..body.len().min(500)]
162        )),
163        Err(err) => MemoryError::Other(format!("Ollama returned HTTP {status}; {err}")),
164    }
165}
166
167/// Parse an Ollama embedding response body into vectors.
168///
169/// Validates that all values are numeric and dimensions match.
170#[doc(hidden)]
171pub fn parse_embedding_response(
172    body: &serde_json::Value,
173    expected_dims: usize,
174) -> Result<Vec<Vec<f32>>, MemoryError> {
175    let embeddings = body["embeddings"].as_array().ok_or_else(|| {
176        MemoryError::Other("Ollama response missing 'embeddings' field".to_string())
177    })?;
178
179    let mut result = Vec::with_capacity(embeddings.len());
180    for embedding_val in embeddings {
181        let raw_array = embedding_val
182            .as_array()
183            .ok_or_else(|| MemoryError::Other("Embedding is not an array".to_string()))?;
184
185        let mut embedding = Vec::with_capacity(raw_array.len());
186        for (i, v) in raw_array.iter().enumerate() {
187            let val = v.as_f64().ok_or_else(|| {
188                MemoryError::Other(format!(
189                    "Embedding dimension {} contains non-numeric value: {}",
190                    i, v
191                ))
192            })?;
193            embedding.push(val as f32);
194        }
195
196        if embedding.len() != expected_dims {
197            return Err(MemoryError::DimensionMismatch {
198                expected: expected_dims,
199                actual: embedding.len(),
200            });
201        }
202
203        result.push(embedding);
204    }
205
206    Ok(result)
207}
208
209// ─── MockEmbedder ────────────────────────────────────────────────
210
/// Test double for [`Embedder`] that needs no external service.
///
/// Each embedding is derived from a hash of the input text, so embedding the
/// same text twice yields the same vector. Vectors are scaled to unit length.
pub struct MockEmbedder {
    /// Length of every vector this embedder produces.
    dimensions: usize,
}

impl MockEmbedder {
    /// Build a `MockEmbedder` whose vectors have `dimensions` components.
    pub fn new(dimensions: usize) -> Self {
        MockEmbedder { dimensions }
    }
}
225
226impl Embedder for MockEmbedder {
227    fn embed<'a>(&'a self, text: &'a str) -> EmbedFuture<'a> {
228        let embedding = deterministic_embedding(text, self.dimensions);
229        Box::pin(async move { Ok(embedding) })
230    }
231
232    fn embed_batch<'a>(&'a self, texts: Vec<String>) -> EmbedBatchFuture<'a> {
233        let embeddings: Vec<Vec<f32>> = texts
234            .iter()
235            .map(|t| deterministic_embedding(t, self.dimensions))
236            .collect();
237        Box::pin(async move { Ok(embeddings) })
238    }
239
240    fn model_name(&self) -> &str {
241        "mock-embedder"
242    }
243
244    fn dimensions(&self) -> usize {
245        self.dimensions
246    }
247}
248
/// Derive a repeatable pseudo-random unit vector of length `dimensions` from `text`.
///
/// The text is hashed with `std::hash::DefaultHasher` to seed a xorshift64
/// generator. Each generator output is mapped onto `[-1.0, 1.0]`, and the
/// resulting vector is scaled to unit Euclidean length. With `dimensions == 0`
/// the result is an empty vector.
///
/// NOTE(review): the standard library does not promise that `DefaultHasher`'s
/// algorithm stays the same across Rust releases, so the exact values should
/// not be relied on across toolchain versions.
fn deterministic_embedding(text: &str, dimensions: usize) -> Vec<f32> {
    let seed = {
        let mut hasher = std::hash::DefaultHasher::new();
        text.hash(&mut hasher);
        // xorshift can never leave the all-zero state, so replace a zero hash.
        match hasher.finish() {
            0 => 1,
            nonzero => nonzero,
        }
    };

    let mut rng = seed;
    let mut values: Vec<f32> = (0..dimensions)
        .map(|_| {
            // One xorshift64 step (shift triple 13, 7, 17).
            rng ^= rng << 13;
            rng ^= rng >> 7;
            rng ^= rng << 17;
            // Scale the 64-bit state to [0, 1], then stretch to [-1, 1].
            let unit = (rng as f64) / (u64::MAX as f64);
            (unit * 2.0 - 1.0) as f32
        })
        .collect();

    // Scale to unit length; skipped when the vector is empty or all zeros.
    let magnitude = values.iter().map(|v| v * v).sum::<f32>().sqrt();
    if magnitude > 0.0 {
        values.iter_mut().for_each(|v| *v /= magnitude);
    }

    values
}