semantic_memory/
embedder.rs1use crate::config::EmbeddingConfig;
7use crate::error::MemoryError;
8use std::future::Future;
9use std::hash::{Hash, Hasher};
10use std::pin::Pin;
11
12pub type EmbedFuture<'a> = Pin<Box<dyn Future<Output = Result<Vec<f32>, MemoryError>> + Send + 'a>>;
14
15pub type EmbedBatchFuture<'a> =
17 Pin<Box<dyn Future<Output = Result<Vec<Vec<f32>>, MemoryError>> + Send + 'a>>;
18
19pub trait Embedder: Send + Sync {
23 fn embed<'a>(&'a self, text: &'a str) -> EmbedFuture<'a>;
25
26 fn embed_batch<'a>(&'a self, texts: Vec<String>) -> EmbedBatchFuture<'a>;
30
31 fn model_name(&self) -> &str;
33
34 fn dimensions(&self) -> usize;
36}
37
38pub struct OllamaEmbedder {
42 client: reqwest::Client,
43 base_url: String,
44 model: String,
45 dimensions: usize,
46 batch_size: usize,
47}
48
49impl OllamaEmbedder {
50 pub fn new(config: &EmbeddingConfig) -> Self {
52 let client = reqwest::Client::builder()
53 .timeout(std::time::Duration::from_secs(config.timeout_secs))
54 .build()
55 .expect("Failed to build reqwest client");
56
57 Self {
58 client,
59 base_url: config.ollama_url.trim_end_matches('/').to_string(),
60 model: config.model.clone(),
61 dimensions: config.dimensions,
62 batch_size: config.batch_size,
63 }
64 }
65}
66
67impl Embedder for OllamaEmbedder {
68 fn embed<'a>(&'a self, text: &'a str) -> EmbedFuture<'a> {
69 Box::pin(async move {
70 let mut results = self.embed_batch(vec![text.to_string()]).await?;
71 results.pop().ok_or_else(|| {
72 MemoryError::Other("Ollama returned empty embeddings for single text".to_string())
73 })
74 })
75 }
76
77 fn embed_batch<'a>(&'a self, texts: Vec<String>) -> EmbedBatchFuture<'a> {
78 Box::pin(async move {
79 let mut all_embeddings = Vec::with_capacity(texts.len());
80
81 for batch in texts.chunks(self.batch_size) {
82 let input: Vec<&str> = batch.iter().map(|s| s.as_str()).collect();
83 let body = serde_json::json!({
84 "model": self.model,
85 "input": input
86 });
87
88 let url = format!("{}/api/embed", self.base_url);
89 let response = self
90 .client
91 .post(&url)
92 .json(&body)
93 .send()
94 .await
95 .map_err(|e| {
96 if e.is_connect() {
97 MemoryError::EmbedderUnavailable(format!(
98 "Ollama not running at {}",
99 self.base_url
100 ))
101 } else if e.is_timeout() {
102 MemoryError::EmbedderUnavailable(format!(
103 "Ollama embedding timed out: {}",
104 e
105 ))
106 } else {
107 MemoryError::EmbeddingRequest(e)
108 }
109 })?;
110
111 if response.status() == reqwest::StatusCode::NOT_FOUND {
112 return Err(MemoryError::EmbedderUnavailable(format!(
113 "Model '{}' not available in Ollama. Run: ollama pull {}",
114 self.model, self.model
115 )));
116 }
117
118 if !response.status().is_success() {
119 let status = response.status();
120 let body = response.text().await.unwrap_or_default();
121 return Err(MemoryError::Other(format!(
122 "Ollama returned HTTP {}: {}",
123 status,
124 &body[..body.len().min(500)]
125 )));
126 }
127
128 let resp_body: serde_json::Value = response.json().await?;
129 let batch_embeddings = parse_embedding_response(&resp_body, self.dimensions)?;
130 all_embeddings.extend(batch_embeddings);
131 }
132
133 Ok(all_embeddings)
134 })
135 }
136
137 fn model_name(&self) -> &str {
138 &self.model
139 }
140
141 fn dimensions(&self) -> usize {
142 self.dimensions
143 }
144}
145
146#[doc(hidden)]
150pub fn parse_embedding_response(
151 body: &serde_json::Value,
152 expected_dims: usize,
153) -> Result<Vec<Vec<f32>>, MemoryError> {
154 let embeddings = body["embeddings"].as_array().ok_or_else(|| {
155 MemoryError::Other("Ollama response missing 'embeddings' field".to_string())
156 })?;
157
158 let mut result = Vec::with_capacity(embeddings.len());
159 for embedding_val in embeddings {
160 let raw_array = embedding_val
161 .as_array()
162 .ok_or_else(|| MemoryError::Other("Embedding is not an array".to_string()))?;
163
164 let mut embedding = Vec::with_capacity(raw_array.len());
165 for (i, v) in raw_array.iter().enumerate() {
166 let val = v.as_f64().ok_or_else(|| {
167 MemoryError::Other(format!(
168 "Embedding dimension {} contains non-numeric value: {}",
169 i, v
170 ))
171 })?;
172 embedding.push(val as f32);
173 }
174
175 if embedding.len() != expected_dims {
176 return Err(MemoryError::DimensionMismatch {
177 expected: expected_dims,
178 actual: embedding.len(),
179 });
180 }
181
182 result.push(embedding);
183 }
184
185 Ok(result)
186}
187
/// [`Embedder`] that needs no external service: it derives a deterministic
/// pseudo-random unit vector from a hash of the input text.
pub struct MockEmbedder {
    /// Number of components in every generated embedding.
    dimensions: usize,
}
197
198impl MockEmbedder {
199 pub fn new(dimensions: usize) -> Self {
201 Self { dimensions }
202 }
203}
204
205impl Embedder for MockEmbedder {
206 fn embed<'a>(&'a self, text: &'a str) -> EmbedFuture<'a> {
207 let embedding = deterministic_embedding(text, self.dimensions);
208 Box::pin(async move { Ok(embedding) })
209 }
210
211 fn embed_batch<'a>(&'a self, texts: Vec<String>) -> EmbedBatchFuture<'a> {
212 let embeddings: Vec<Vec<f32>> = texts
213 .iter()
214 .map(|t| deterministic_embedding(t, self.dimensions))
215 .collect();
216 Box::pin(async move { Ok(embeddings) })
217 }
218
219 fn model_name(&self) -> &str {
220 "mock-embedder"
221 }
222
223 fn dimensions(&self) -> usize {
224 self.dimensions
225 }
226}
227
/// Derives a pseudo-random, L2-normalised vector of `dimensions` components
/// from `text`.
///
/// The text is hashed with `DefaultHasher` and the hash seeds an
/// xorshift-style generator, so equal texts always map to equal vectors within
/// one build. The standard library does not specify `DefaultHasher`'s
/// algorithm across releases, so the concrete values should not be persisted
/// or compared across toolchains.
///
/// Returns an empty vector when `dimensions` is 0.
fn deterministic_embedding(text: &str, dimensions: usize) -> Vec<f32> {
    let mut hasher = std::hash::DefaultHasher::new();
    text.hash(&mut hasher);
    let mut state = hasher.finish();
    // A zero state would stay zero under the xor/shift steps below.
    if state == 0 {
        state = 1;
    }

    let mut values = Vec::with_capacity(dimensions);
    for _ in 0..dimensions {
        state ^= state << 13;
        state ^= state >> 7;
        state ^= state << 17;
        // Map the 64-bit state onto [-1.0, 1.0].
        let val = ((state as f64) / (u64::MAX as f64)) * 2.0 - 1.0;
        values.push(val as f32);
    }

    // Scale to unit length; skipped when the norm is zero (e.g. no components).
    let magnitude: f32 = values.iter().map(|v| v * v).sum::<f32>().sqrt();
    if magnitude > 0.0 {
        for v in &mut values {
            *v /= magnitude;
        }
    }

    values
}