semantic_memory/
embedder.rs1use crate::config::EmbeddingConfig;
7use crate::error::MemoryError;
8use std::future::Future;
9use std::hash::{Hash, Hasher};
10use std::pin::Pin;
11
12pub type EmbedFuture<'a> = Pin<Box<dyn Future<Output = Result<Vec<f32>, MemoryError>> + Send + 'a>>;
14
15pub type EmbedBatchFuture<'a> =
17 Pin<Box<dyn Future<Output = Result<Vec<Vec<f32>>, MemoryError>> + Send + 'a>>;
18
19pub trait Embedder: Send + Sync {
23 fn embed<'a>(&'a self, text: &'a str) -> EmbedFuture<'a>;
25
26 fn embed_batch<'a>(&'a self, texts: Vec<String>) -> EmbedBatchFuture<'a>;
30
31 fn model_name(&self) -> &str;
33
34 fn dimensions(&self) -> usize;
36}
37
38pub struct OllamaEmbedder {
42 client: reqwest::Client,
43 base_url: String,
44 model: String,
45 dimensions: usize,
46 batch_size: usize,
47}
48
49impl OllamaEmbedder {
50 pub fn try_new(config: &EmbeddingConfig) -> Result<Self, MemoryError> {
55 let client = reqwest::Client::builder()
56 .timeout(std::time::Duration::from_secs(config.timeout_secs))
57 .build()
58 .map_err(|e| {
59 MemoryError::EmbedderUnavailable(format!("failed to build HTTP client: {e}"))
60 })?;
61
62 Ok(Self {
63 client,
64 base_url: config.ollama_url.trim_end_matches('/').to_string(),
65 model: config.model.clone(),
66 dimensions: config.dimensions,
67 batch_size: config.batch_size,
68 })
69 }
70
71 }
73
74impl Embedder for OllamaEmbedder {
75 fn embed<'a>(&'a self, text: &'a str) -> EmbedFuture<'a> {
76 Box::pin(async move {
77 let mut results = self.embed_batch(vec![text.to_string()]).await?;
78 results.pop().ok_or_else(|| {
79 MemoryError::Other("Ollama returned empty embeddings for single text".to_string())
80 })
81 })
82 }
83
84 fn embed_batch<'a>(&'a self, texts: Vec<String>) -> EmbedBatchFuture<'a> {
85 Box::pin(async move {
86 let mut all_embeddings = Vec::with_capacity(texts.len());
87
88 for batch in texts.chunks(self.batch_size) {
89 let input: Vec<&str> = batch.iter().map(|s| s.as_str()).collect();
90 let body = serde_json::json!({
91 "model": self.model,
92 "input": input
93 });
94
95 let url = format!("{}/api/embed", self.base_url);
96 let response = self
97 .client
98 .post(&url)
99 .json(&body)
100 .send()
101 .await
102 .map_err(|e| {
103 if e.is_connect() {
104 MemoryError::EmbedderUnavailable(format!(
105 "Ollama not running at {}",
106 self.base_url
107 ))
108 } else if e.is_timeout() {
109 MemoryError::EmbedderUnavailable(format!(
110 "Ollama embedding timed out: {}",
111 e
112 ))
113 } else {
114 MemoryError::EmbeddingRequest(e)
115 }
116 })?;
117
118 if response.status() == reqwest::StatusCode::NOT_FOUND {
119 return Err(MemoryError::EmbedderUnavailable(format!(
120 "Model '{}' not available in Ollama. Run: ollama pull {}",
121 self.model, self.model
122 )));
123 }
124
125 if !response.status().is_success() {
126 let status = response.status();
127 let body = response
128 .text()
129 .await
130 .map_err(|err| format!("failed to read Ollama error body: {err}"));
131 return Err(format_ollama_http_error(status, body));
132 }
133
134 let resp_body: serde_json::Value = response.json().await?;
135 let batch_embeddings = parse_embedding_response(&resp_body, self.dimensions)?;
136 all_embeddings.extend(batch_embeddings);
137 }
138
139 Ok(all_embeddings)
140 })
141 }
142
143 fn model_name(&self) -> &str {
144 &self.model
145 }
146
147 fn dimensions(&self) -> usize {
148 self.dimensions
149 }
150}
151
152#[doc(hidden)]
153pub fn format_ollama_http_error(
154 status: reqwest::StatusCode,
155 body: Result<String, String>,
156) -> MemoryError {
157 match body {
158 Ok(body) => MemoryError::Other(format!(
159 "Ollama returned HTTP {}: {}",
160 status,
161 &body[..body.len().min(500)]
162 )),
163 Err(err) => MemoryError::Other(format!("Ollama returned HTTP {status}; {err}")),
164 }
165}
166
167#[doc(hidden)]
171pub fn parse_embedding_response(
172 body: &serde_json::Value,
173 expected_dims: usize,
174) -> Result<Vec<Vec<f32>>, MemoryError> {
175 let embeddings = body["embeddings"].as_array().ok_or_else(|| {
176 MemoryError::Other("Ollama response missing 'embeddings' field".to_string())
177 })?;
178
179 let mut result = Vec::with_capacity(embeddings.len());
180 for embedding_val in embeddings {
181 let raw_array = embedding_val
182 .as_array()
183 .ok_or_else(|| MemoryError::Other("Embedding is not an array".to_string()))?;
184
185 let mut embedding = Vec::with_capacity(raw_array.len());
186 for (i, v) in raw_array.iter().enumerate() {
187 let val = v.as_f64().ok_or_else(|| {
188 MemoryError::Other(format!(
189 "Embedding dimension {} contains non-numeric value: {}",
190 i, v
191 ))
192 })?;
193 embedding.push(val as f32);
194 }
195
196 if embedding.len() != expected_dims {
197 return Err(MemoryError::DimensionMismatch {
198 expected: expected_dims,
199 actual: embedding.len(),
200 });
201 }
202
203 result.push(embedding);
204 }
205
206 Ok(result)
207}
208
/// Network-free [`Embedder`] that derives deterministic vectors from a hash of
/// the input text; intended for tests.
///
/// Derives `Debug` and `Clone` so the public type can be inspected and duplicated.
#[derive(Debug, Clone)]
pub struct MockEmbedder {
    /// Length of every vector this embedder produces.
    dimensions: usize,
}
218
219impl MockEmbedder {
220 pub fn new(dimensions: usize) -> Self {
222 Self { dimensions }
223 }
224}
225
226impl Embedder for MockEmbedder {
227 fn embed<'a>(&'a self, text: &'a str) -> EmbedFuture<'a> {
228 let embedding = deterministic_embedding(text, self.dimensions);
229 Box::pin(async move { Ok(embedding) })
230 }
231
232 fn embed_batch<'a>(&'a self, texts: Vec<String>) -> EmbedBatchFuture<'a> {
233 let embeddings: Vec<Vec<f32>> = texts
234 .iter()
235 .map(|t| deterministic_embedding(t, self.dimensions))
236 .collect();
237 Box::pin(async move { Ok(embeddings) })
238 }
239
240 fn model_name(&self) -> &str {
241 "mock-embedder"
242 }
243
244 fn dimensions(&self) -> usize {
245 self.dimensions
246 }
247}
248
/// Derives a pseudo-random vector of length `dimensions` from `text`.
///
/// The text is hashed with [`std::hash::DefaultHasher`] to seed an xorshift64
/// generator (shift triple 13/7/17). Each step maps the generator state into
/// `[-1.0, 1.0]`. The result is L2-normalised whenever its magnitude is
/// positive.
///
/// The same text always yields the same vector within one build.
/// `DefaultHasher`'s algorithm is not guaranteed to be stable across Rust
/// releases, so vectors may differ between toolchains.
fn deterministic_embedding(text: &str, dimensions: usize) -> Vec<f32> {
    let mut seed_hasher = std::hash::DefaultHasher::new();
    text.hash(&mut seed_hasher);
    // xorshift never leaves the all-zero state, so force a non-zero seed.
    let mut rng = seed_hasher.finish().max(1);

    let mut components: Vec<f32> = (0..dimensions)
        .map(|_| {
            rng ^= rng << 13;
            rng ^= rng >> 7;
            rng ^= rng << 17;
            let unit = (rng as f64) / (u64::MAX as f64);
            (unit * 2.0 - 1.0) as f32
        })
        .collect();

    let norm = components.iter().map(|c| c * c).sum::<f32>().sqrt();
    if norm > 0.0 {
        components.iter_mut().for_each(|c| *c /= norm);
    }
    components
}