1use async_trait::async_trait;
7use std::sync::Arc;
8use std::time::Duration;
9
10use super::types::ContextError;
11use super::vector_db::{EmbeddingService, MockEmbeddingService};
12
/// Which backend generates embeddings.
///
/// Selected explicitly via `EMBEDDING_PROVIDER`, or auto-detected in
/// [`EmbeddingConfig::from_env`] from the base URL / API key.
#[derive(Debug, Clone, PartialEq)]
pub enum EmbeddingProvider {
    /// A local Ollama server (defaults to `http://localhost:11434`).
    Ollama,
    /// An OpenAI-compatible HTTP API; requires an API key.
    OpenAi,
}
19
/// Resolved configuration for an embedding backend.
#[derive(Debug, Clone)]
pub struct EmbeddingConfig {
    /// Backend selection (Ollama or OpenAI-compatible).
    pub provider: EmbeddingProvider,
    /// Model name sent with every request (e.g. "nomic-embed-text").
    pub model: String,
    /// API base URL; trailing slashes are stripped by the service constructors.
    pub base_url: String,
    /// Bearer token for OpenAI-style APIs; `None` for Ollama.
    pub api_key: Option<String>,
    /// Expected embedding vector length for the chosen model.
    pub dimension: usize,
    /// Per-request HTTP timeout in seconds.
    pub timeout_seconds: u64,
}
30
31impl EmbeddingConfig {
32 pub fn from_env() -> Option<Self> {
43 let api_key = std::env::var("EMBEDDING_API_KEY")
44 .ok()
45 .or_else(|| std::env::var("OPENAI_API_KEY").ok())
46 .filter(|k| !k.is_empty());
47
48 let base_url = std::env::var("EMBEDDING_API_BASE_URL")
49 .ok()
50 .or_else(|| std::env::var("OPENAI_API_BASE_URL").ok())
51 .filter(|u| !u.is_empty());
52
53 let explicit_provider = std::env::var("EMBEDDING_PROVIDER")
54 .ok()
55 .filter(|p| !p.is_empty());
56
57 let provider = if let Some(ref p) = explicit_provider {
58 match p.to_lowercase().as_str() {
59 "ollama" => EmbeddingProvider::Ollama,
60 "openai" => EmbeddingProvider::OpenAi,
61 _ => return None,
62 }
63 } else if let Some(ref url) = base_url {
64 if url.contains("localhost") || url.contains("127.0.0.1") {
65 EmbeddingProvider::Ollama
66 } else if api_key.is_some() {
67 EmbeddingProvider::OpenAi
68 } else {
69 return None;
70 }
71 } else if api_key.is_some() {
72 EmbeddingProvider::OpenAi
73 } else {
74 return None;
75 };
76
77 let (default_model, default_url, default_dim) = match provider {
78 EmbeddingProvider::Ollama => (
79 "nomic-embed-text".to_string(),
80 "http://localhost:11434".to_string(),
81 768,
82 ),
83 EmbeddingProvider::OpenAi => (
84 "text-embedding-3-small".to_string(),
85 "https://api.openai.com/v1".to_string(),
86 1536,
87 ),
88 };
89
90 let model = std::env::var("EMBEDDING_MODEL")
91 .ok()
92 .filter(|m| !m.is_empty())
93 .unwrap_or(default_model);
94
95 let final_url = base_url.unwrap_or(default_url);
96
97 let dimension = std::env::var("VECTOR_DIMENSION")
98 .ok()
99 .and_then(|d| d.parse::<usize>().ok())
100 .unwrap_or(default_dim);
101
102 Some(Self {
103 provider,
104 model,
105 base_url: final_url,
106 api_key,
107 dimension,
108 timeout_seconds: 30,
109 })
110 }
111}
112
/// Embedding client for a local Ollama server (`/api/embed` endpoint).
pub struct OllamaEmbeddingService {
    // Shared HTTP client with the configured timeout applied.
    client: reqwest::Client,
    // Model name sent with each request.
    model: String,
    // Base URL with any trailing '/' removed at construction time.
    base_url: String,
    // Reported vector length; taken from config, not verified against responses.
    dimension: usize,
}
120
121impl OllamaEmbeddingService {
122 pub fn new(config: &EmbeddingConfig) -> Result<Self, ContextError> {
123 let client = reqwest::Client::builder()
124 .timeout(Duration::from_secs(config.timeout_seconds))
125 .build()
126 .map_err(|e| ContextError::EmbeddingError {
127 reason: format!("Failed to create HTTP client: {e}"),
128 })?;
129
130 Ok(Self {
131 client,
132 model: config.model.clone(),
133 base_url: config.base_url.trim_end_matches('/').to_string(),
134 dimension: config.dimension,
135 })
136 }
137}
138
139#[async_trait]
140impl EmbeddingService for OllamaEmbeddingService {
141 async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>, ContextError> {
142 let mut results = self.generate_batch_embeddings(vec![text]).await?;
143 results.pop().ok_or_else(|| ContextError::EmbeddingError {
144 reason: "Empty response from Ollama".to_string(),
145 })
146 }
147
148 async fn generate_batch_embeddings(
149 &self,
150 texts: Vec<&str>,
151 ) -> Result<Vec<Vec<f32>>, ContextError> {
152 let url = format!("{}/api/embed", self.base_url);
153
154 let body = serde_json::json!({
155 "model": self.model,
156 "input": texts,
157 });
158
159 let resp = self
160 .client
161 .post(&url)
162 .json(&body)
163 .send()
164 .await
165 .map_err(|e| ContextError::EmbeddingError {
166 reason: format!("Ollama request failed: {e}"),
167 })?;
168
169 if !resp.status().is_success() {
170 let status = resp.status();
171 let body_text = resp.text().await.unwrap_or_default();
172 return Err(ContextError::EmbeddingError {
173 reason: format!("Ollama returned {status}: {body_text}"),
174 });
175 }
176
177 let json: serde_json::Value =
178 resp.json()
179 .await
180 .map_err(|e| ContextError::EmbeddingError {
181 reason: format!("Failed to parse Ollama response: {e}"),
182 })?;
183
184 let embeddings = json
185 .get("embeddings")
186 .and_then(|v| v.as_array())
187 .ok_or_else(|| ContextError::EmbeddingError {
188 reason: "Missing 'embeddings' field in Ollama response".to_string(),
189 })?;
190
191 embeddings
192 .iter()
193 .map(|emb| {
194 emb.as_array()
195 .ok_or_else(|| ContextError::EmbeddingError {
196 reason: "Invalid embedding array in Ollama response".to_string(),
197 })?
198 .iter()
199 .map(|v| {
200 v.as_f64()
201 .map(|f| f as f32)
202 .ok_or_else(|| ContextError::EmbeddingError {
203 reason: "Invalid float in embedding".to_string(),
204 })
205 })
206 .collect::<Result<Vec<f32>, _>>()
207 })
208 .collect()
209 }
210
211 fn embedding_dimension(&self) -> usize {
212 self.dimension
213 }
214
215 fn max_text_length(&self) -> usize {
216 8192
217 }
218}
219
/// Embedding client for OpenAI-compatible `/embeddings` APIs.
pub struct OpenAiEmbeddingService {
    // Shared HTTP client with the configured timeout applied.
    client: reqwest::Client,
    // Model name sent with each request.
    model: String,
    // Base URL with any trailing '/' removed at construction time.
    base_url: String,
    // Bearer token; construction fails if the config has no non-empty key.
    api_key: String,
    // Reported vector length; taken from config, not verified against responses.
    dimension: usize,
}
228
229impl OpenAiEmbeddingService {
230 pub fn new(config: &EmbeddingConfig) -> Result<Self, ContextError> {
231 let api_key = config
232 .api_key
233 .clone()
234 .filter(|k| !k.is_empty())
235 .ok_or_else(|| ContextError::EmbeddingError {
236 reason: "OpenAI embedding service requires an API key".to_string(),
237 })?;
238
239 let client = reqwest::Client::builder()
240 .timeout(Duration::from_secs(config.timeout_seconds))
241 .build()
242 .map_err(|e| ContextError::EmbeddingError {
243 reason: format!("Failed to create HTTP client: {e}"),
244 })?;
245
246 Ok(Self {
247 client,
248 model: config.model.clone(),
249 base_url: config.base_url.trim_end_matches('/').to_string(),
250 api_key,
251 dimension: config.dimension,
252 })
253 }
254}
255
256#[async_trait]
257impl EmbeddingService for OpenAiEmbeddingService {
258 async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>, ContextError> {
259 let mut results = self.generate_batch_embeddings(vec![text]).await?;
260 results.pop().ok_or_else(|| ContextError::EmbeddingError {
261 reason: "Empty response from OpenAI".to_string(),
262 })
263 }
264
265 async fn generate_batch_embeddings(
266 &self,
267 texts: Vec<&str>,
268 ) -> Result<Vec<Vec<f32>>, ContextError> {
269 let url = format!("{}/embeddings", self.base_url);
270
271 let body = serde_json::json!({
272 "model": self.model,
273 "input": texts,
274 });
275
276 let resp = self
277 .client
278 .post(&url)
279 .bearer_auth(&self.api_key)
280 .json(&body)
281 .send()
282 .await
283 .map_err(|e| ContextError::EmbeddingError {
284 reason: format!("OpenAI request failed: {e}"),
285 })?;
286
287 if !resp.status().is_success() {
288 let status = resp.status();
289 let body_text = resp.text().await.unwrap_or_default();
290 return Err(ContextError::EmbeddingError {
291 reason: format!("OpenAI returned {status}: {body_text}"),
292 });
293 }
294
295 let json: serde_json::Value =
296 resp.json()
297 .await
298 .map_err(|e| ContextError::EmbeddingError {
299 reason: format!("Failed to parse OpenAI response: {e}"),
300 })?;
301
302 if let Some(usage) = json.get("usage") {
304 tracing::debug!(
305 prompt_tokens = usage.get("prompt_tokens").and_then(|v| v.as_u64()),
306 total_tokens = usage.get("total_tokens").and_then(|v| v.as_u64()),
307 "OpenAI embedding token usage"
308 );
309 }
310
311 let data = json.get("data").and_then(|v| v.as_array()).ok_or_else(|| {
312 ContextError::EmbeddingError {
313 reason: "Missing 'data' field in OpenAI response".to_string(),
314 }
315 })?;
316
317 let mut indexed: Vec<(usize, Vec<f32>)> = data
319 .iter()
320 .map(|item| {
321 let index = item.get("index").and_then(|v| v.as_u64()).unwrap_or(0) as usize;
322
323 let embedding = item
324 .get("embedding")
325 .and_then(|v| v.as_array())
326 .ok_or_else(|| ContextError::EmbeddingError {
327 reason: "Missing 'embedding' in OpenAI response item".to_string(),
328 })?
329 .iter()
330 .map(|v| {
331 v.as_f64()
332 .map(|f| f as f32)
333 .ok_or_else(|| ContextError::EmbeddingError {
334 reason: "Invalid float in embedding".to_string(),
335 })
336 })
337 .collect::<Result<Vec<f32>, _>>()?;
338
339 Ok((index, embedding))
340 })
341 .collect::<Result<Vec<_>, ContextError>>()?;
342
343 indexed.sort_by_key(|(i, _)| *i);
344
345 Ok(indexed.into_iter().map(|(_, emb)| emb).collect())
346 }
347
348 fn embedding_dimension(&self) -> usize {
349 self.dimension
350 }
351
352 fn max_text_length(&self) -> usize {
353 8191 }
355}
356
357pub fn create_embedding_service(
359 config: &EmbeddingConfig,
360) -> Result<Arc<dyn EmbeddingService>, ContextError> {
361 match config.provider {
362 EmbeddingProvider::Ollama => {
363 tracing::info!(
364 model = %config.model,
365 url = %config.base_url,
366 dimension = config.dimension,
367 "Using Ollama embedding service"
368 );
369 Ok(Arc::new(OllamaEmbeddingService::new(config)?))
370 }
371 EmbeddingProvider::OpenAi => {
372 tracing::info!(
373 model = %config.model,
374 url = %config.base_url,
375 dimension = config.dimension,
376 "Using OpenAI embedding service"
377 );
378 Ok(Arc::new(OpenAiEmbeddingService::new(config)?))
379 }
380 }
381}
382
383pub fn create_embedding_service_from_env(
386 fallback_dimension: usize,
387) -> Result<Arc<dyn EmbeddingService>, ContextError> {
388 match EmbeddingConfig::from_env() {
389 Some(config) => create_embedding_service(&config),
390 None => {
391 tracing::debug!(
392 dimension = fallback_dimension,
393 "No embedding provider configured, using mock embedding service"
394 );
395 Ok(Arc::new(MockEmbeddingService::new(fallback_dimension)))
396 }
397 }
398}
399
#[cfg(test)]
mod tests {
    use super::*;
    use serial_test::serial;

    // Environment variables are process-global, so every test below is marked
    // #[serial] and starts by wiping the variables from_env reads.
    fn clear_env() {
        for var in &[
            "EMBEDDING_PROVIDER",
            "EMBEDDING_API_KEY",
            "OPENAI_API_KEY",
            "EMBEDDING_API_BASE_URL",
            "OPENAI_API_BASE_URL",
            "EMBEDDING_MODEL",
            "VECTOR_DIMENSION",
        ] {
            std::env::remove_var(var);
        }
    }

    // Explicit "ollama" provider with nothing else set: all Ollama defaults apply.
    #[test]
    #[serial]
    fn test_embedding_config_defaults_ollama() {
        clear_env();
        std::env::set_var("EMBEDDING_PROVIDER", "ollama");

        let config = EmbeddingConfig::from_env().expect("should resolve");
        assert_eq!(config.provider, EmbeddingProvider::Ollama);
        assert_eq!(config.model, "nomic-embed-text");
        assert_eq!(config.base_url, "http://localhost:11434");
        assert_eq!(config.dimension, 768);
        assert!(config.api_key.is_none());
    }

    // Explicit "openai" provider plus a key: all OpenAI defaults apply.
    #[test]
    #[serial]
    fn test_embedding_config_defaults_openai() {
        clear_env();
        std::env::set_var("EMBEDDING_PROVIDER", "openai");
        std::env::set_var("OPENAI_API_KEY", "sk-test");

        let config = EmbeddingConfig::from_env().expect("should resolve");
        assert_eq!(config.provider, EmbeddingProvider::OpenAi);
        assert_eq!(config.model, "text-embedding-3-small");
        assert_eq!(config.base_url, "https://api.openai.com/v1");
        assert_eq!(config.dimension, 1536);
        assert_eq!(config.api_key.as_deref(), Some("sk-test"));
    }

    // No explicit provider: an API key alone auto-selects OpenAI.
    #[test]
    #[serial]
    fn test_embedding_config_auto_detect_openai_from_key() {
        clear_env();
        std::env::set_var("OPENAI_API_KEY", "sk-auto");

        let config = EmbeddingConfig::from_env().expect("should resolve");
        assert_eq!(config.provider, EmbeddingProvider::OpenAi);
        assert_eq!(config.api_key.as_deref(), Some("sk-auto"));
    }

    // No explicit provider: a localhost base URL auto-selects Ollama.
    #[test]
    #[serial]
    fn test_embedding_config_auto_detect_ollama_from_localhost_url() {
        clear_env();
        std::env::set_var("EMBEDDING_API_BASE_URL", "http://localhost:11434");

        let config = EmbeddingConfig::from_env().expect("should resolve");
        assert_eq!(config.provider, EmbeddingProvider::Ollama);
    }

    // With no relevant variables set, from_env refuses to pick a provider.
    #[test]
    #[serial]
    fn test_embedding_config_none_when_no_provider() {
        clear_env();
        assert!(EmbeddingConfig::from_env().is_none());
    }

    // VECTOR_DIMENSION overrides the provider's default dimension.
    #[test]
    #[serial]
    fn test_embedding_config_dimension_override() {
        clear_env();
        std::env::set_var("EMBEDDING_PROVIDER", "ollama");
        std::env::set_var("VECTOR_DIMENSION", "1024");

        let config = EmbeddingConfig::from_env().expect("should resolve");
        assert_eq!(config.dimension, 1024);
    }

    // With no configuration, the factory hands back a mock sized to the
    // requested fallback dimension.
    #[test]
    #[serial]
    fn test_create_embedding_service_from_env_fallback() {
        clear_env();

        let svc = create_embedding_service_from_env(256).expect("should return mock");
        assert_eq!(svc.embedding_dimension(), 256);
    }

    // The mock fallback produces vectors of the requested length; the final
    // assertion checks the vector is (approximately) unit-norm.
    #[tokio::test]
    #[serial]
    async fn test_mock_fallback_generates_embeddings() {
        clear_env();

        let svc = create_embedding_service_from_env(128).expect("should return mock");
        let emb = svc.generate_embedding("hello world").await.unwrap();
        assert_eq!(emb.len(), 128);

        let mag: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((mag - 1.0).abs() < 0.01);
    }
}