// symbi_runtime/context/embedding.rs
1//! Embedding service providers for generating vector embeddings
2//!
3//! Supports Ollama (local) and OpenAI (cloud) embedding providers,
4//! with automatic provider detection from environment variables.
5
6use async_trait::async_trait;
7use std::sync::Arc;
8use std::time::Duration;
9
10use super::types::ContextError;
11use super::vector_db::{EmbeddingService, MockEmbeddingService};
12
/// Embedding provider selection
///
/// Fieldless and cheap to pass around, so it derives `Copy`; `Eq` and `Hash`
/// make it usable as a map key and in exhaustive equality checks.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EmbeddingProvider {
    /// Local Ollama server (defaults to `http://localhost:11434`).
    Ollama,
    /// OpenAI (or OpenAI-compatible) cloud API.
    OpenAi,
}
19
/// Configuration for an embedding service provider
#[derive(Debug, Clone)]
pub struct EmbeddingConfig {
    /// Which backend to use (Ollama or OpenAI).
    pub provider: EmbeddingProvider,
    /// Model name sent with each embedding request (e.g. "nomic-embed-text").
    pub model: String,
    /// Base URL of the provider API; services strip any trailing '/'.
    pub base_url: String,
    /// API key — required by the OpenAI service, not read by the Ollama one.
    pub api_key: Option<String>,
    /// Embedding vector dimension reported to callers (not validated against
    /// actual provider responses).
    pub dimension: usize,
    /// HTTP request timeout in seconds.
    pub timeout_seconds: u64,
}
30
31impl EmbeddingConfig {
32    /// Resolve embedding configuration from environment variables.
33    ///
34    /// Returns `None` if no provider can be determined (no env vars set),
35    /// which signals the caller to fall back to the mock service.
36    ///
37    /// Resolution order:
38    /// 1. API key: `EMBEDDING_API_KEY` → `OPENAI_API_KEY` → None
39    /// 2. Provider: `EMBEDDING_PROVIDER` explicit, or auto-detect from URL/key
40    /// 3. Per-provider defaults for model, URL, and dimension
41    /// 4. Overrides: `EMBEDDING_MODEL`, `EMBEDDING_API_BASE_URL`, `VECTOR_DIMENSION`
42    pub fn from_env() -> Option<Self> {
43        let api_key = std::env::var("EMBEDDING_API_KEY")
44            .ok()
45            .or_else(|| std::env::var("OPENAI_API_KEY").ok())
46            .filter(|k| !k.is_empty());
47
48        let base_url = std::env::var("EMBEDDING_API_BASE_URL")
49            .ok()
50            .or_else(|| std::env::var("OPENAI_API_BASE_URL").ok())
51            .filter(|u| !u.is_empty());
52
53        let explicit_provider = std::env::var("EMBEDDING_PROVIDER")
54            .ok()
55            .filter(|p| !p.is_empty());
56
57        let provider = if let Some(ref p) = explicit_provider {
58            match p.to_lowercase().as_str() {
59                "ollama" => EmbeddingProvider::Ollama,
60                "openai" => EmbeddingProvider::OpenAi,
61                _ => return None,
62            }
63        } else if let Some(ref url) = base_url {
64            if url.contains("localhost") || url.contains("127.0.0.1") {
65                EmbeddingProvider::Ollama
66            } else if api_key.is_some() {
67                EmbeddingProvider::OpenAi
68            } else {
69                return None;
70            }
71        } else if api_key.is_some() {
72            EmbeddingProvider::OpenAi
73        } else {
74            return None;
75        };
76
77        let (default_model, default_url, default_dim) = match provider {
78            EmbeddingProvider::Ollama => (
79                "nomic-embed-text".to_string(),
80                "http://localhost:11434".to_string(),
81                768,
82            ),
83            EmbeddingProvider::OpenAi => (
84                "text-embedding-3-small".to_string(),
85                "https://api.openai.com/v1".to_string(),
86                1536,
87            ),
88        };
89
90        let model = std::env::var("EMBEDDING_MODEL")
91            .ok()
92            .filter(|m| !m.is_empty())
93            .unwrap_or(default_model);
94
95        let final_url = base_url.unwrap_or(default_url);
96
97        let dimension = std::env::var("VECTOR_DIMENSION")
98            .ok()
99            .and_then(|d| d.parse::<usize>().ok())
100            .unwrap_or(default_dim);
101
102        Some(Self {
103            provider,
104            model,
105            base_url: final_url,
106            api_key,
107            dimension,
108            timeout_seconds: 30,
109        })
110    }
111}
112
/// Ollama embedding service using the native `/api/embed` endpoint
pub struct OllamaEmbeddingService {
    /// Reused HTTP client, built with the configured request timeout.
    client: reqwest::Client,
    /// Model name sent with every embed request.
    model: String,
    /// Base URL with any trailing '/' already stripped.
    base_url: String,
    /// Dimension reported by `embedding_dimension()`; not validated against
    /// what the server actually returns.
    dimension: usize,
}
120
121impl OllamaEmbeddingService {
122    pub fn new(config: &EmbeddingConfig) -> Result<Self, ContextError> {
123        let client = reqwest::Client::builder()
124            .timeout(Duration::from_secs(config.timeout_seconds))
125            .build()
126            .map_err(|e| ContextError::EmbeddingError {
127                reason: format!("Failed to create HTTP client: {e}"),
128            })?;
129
130        Ok(Self {
131            client,
132            model: config.model.clone(),
133            base_url: config.base_url.trim_end_matches('/').to_string(),
134            dimension: config.dimension,
135        })
136    }
137}
138
139#[async_trait]
140impl EmbeddingService for OllamaEmbeddingService {
141    async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>, ContextError> {
142        let mut results = self.generate_batch_embeddings(vec![text]).await?;
143        results.pop().ok_or_else(|| ContextError::EmbeddingError {
144            reason: "Empty response from Ollama".to_string(),
145        })
146    }
147
148    async fn generate_batch_embeddings(
149        &self,
150        texts: Vec<&str>,
151    ) -> Result<Vec<Vec<f32>>, ContextError> {
152        let url = format!("{}/api/embed", self.base_url);
153
154        let body = serde_json::json!({
155            "model": self.model,
156            "input": texts,
157        });
158
159        let resp = self
160            .client
161            .post(&url)
162            .json(&body)
163            .send()
164            .await
165            .map_err(|e| ContextError::EmbeddingError {
166                reason: format!("Ollama request failed: {e}"),
167            })?;
168
169        if !resp.status().is_success() {
170            let status = resp.status();
171            let body_text = resp.text().await.unwrap_or_default();
172            return Err(ContextError::EmbeddingError {
173                reason: format!("Ollama returned {status}: {body_text}"),
174            });
175        }
176
177        let json: serde_json::Value =
178            resp.json()
179                .await
180                .map_err(|e| ContextError::EmbeddingError {
181                    reason: format!("Failed to parse Ollama response: {e}"),
182                })?;
183
184        let embeddings = json
185            .get("embeddings")
186            .and_then(|v| v.as_array())
187            .ok_or_else(|| ContextError::EmbeddingError {
188                reason: "Missing 'embeddings' field in Ollama response".to_string(),
189            })?;
190
191        embeddings
192            .iter()
193            .map(|emb| {
194                emb.as_array()
195                    .ok_or_else(|| ContextError::EmbeddingError {
196                        reason: "Invalid embedding array in Ollama response".to_string(),
197                    })?
198                    .iter()
199                    .map(|v| {
200                        v.as_f64()
201                            .map(|f| f as f32)
202                            .ok_or_else(|| ContextError::EmbeddingError {
203                                reason: "Invalid float in embedding".to_string(),
204                            })
205                    })
206                    .collect::<Result<Vec<f32>, _>>()
207            })
208            .collect()
209    }
210
211    fn embedding_dimension(&self) -> usize {
212        self.dimension
213    }
214
215    fn max_text_length(&self) -> usize {
216        8192
217    }
218}
219
/// OpenAI-compatible embedding service
pub struct OpenAiEmbeddingService {
    /// Reused HTTP client, built with the configured request timeout.
    client: reqwest::Client,
    /// Model name sent with every embeddings request.
    model: String,
    /// Base URL with any trailing '/' already stripped.
    base_url: String,
    /// Bearer token sent with each request; guaranteed non-empty by `new`.
    api_key: String,
    /// Dimension reported by `embedding_dimension()`; not validated against
    /// what the API actually returns.
    dimension: usize,
}
228
229impl OpenAiEmbeddingService {
230    pub fn new(config: &EmbeddingConfig) -> Result<Self, ContextError> {
231        let api_key = config
232            .api_key
233            .clone()
234            .filter(|k| !k.is_empty())
235            .ok_or_else(|| ContextError::EmbeddingError {
236                reason: "OpenAI embedding service requires an API key".to_string(),
237            })?;
238
239        let client = reqwest::Client::builder()
240            .timeout(Duration::from_secs(config.timeout_seconds))
241            .build()
242            .map_err(|e| ContextError::EmbeddingError {
243                reason: format!("Failed to create HTTP client: {e}"),
244            })?;
245
246        Ok(Self {
247            client,
248            model: config.model.clone(),
249            base_url: config.base_url.trim_end_matches('/').to_string(),
250            api_key,
251            dimension: config.dimension,
252        })
253    }
254}
255
256#[async_trait]
257impl EmbeddingService for OpenAiEmbeddingService {
258    async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>, ContextError> {
259        let mut results = self.generate_batch_embeddings(vec![text]).await?;
260        results.pop().ok_or_else(|| ContextError::EmbeddingError {
261            reason: "Empty response from OpenAI".to_string(),
262        })
263    }
264
265    async fn generate_batch_embeddings(
266        &self,
267        texts: Vec<&str>,
268    ) -> Result<Vec<Vec<f32>>, ContextError> {
269        let url = format!("{}/embeddings", self.base_url);
270
271        let body = serde_json::json!({
272            "model": self.model,
273            "input": texts,
274        });
275
276        let resp = self
277            .client
278            .post(&url)
279            .bearer_auth(&self.api_key)
280            .json(&body)
281            .send()
282            .await
283            .map_err(|e| ContextError::EmbeddingError {
284                reason: format!("OpenAI request failed: {e}"),
285            })?;
286
287        if !resp.status().is_success() {
288            let status = resp.status();
289            let body_text = resp.text().await.unwrap_or_default();
290            return Err(ContextError::EmbeddingError {
291                reason: format!("OpenAI returned {status}: {body_text}"),
292            });
293        }
294
295        let json: serde_json::Value =
296            resp.json()
297                .await
298                .map_err(|e| ContextError::EmbeddingError {
299                    reason: format!("Failed to parse OpenAI response: {e}"),
300                })?;
301
302        // Log token usage
303        if let Some(usage) = json.get("usage") {
304            tracing::debug!(
305                prompt_tokens = usage.get("prompt_tokens").and_then(|v| v.as_u64()),
306                total_tokens = usage.get("total_tokens").and_then(|v| v.as_u64()),
307                "OpenAI embedding token usage"
308            );
309        }
310
311        let data = json.get("data").and_then(|v| v.as_array()).ok_or_else(|| {
312            ContextError::EmbeddingError {
313                reason: "Missing 'data' field in OpenAI response".to_string(),
314            }
315        })?;
316
317        // Sort by index to ensure correct ordering
318        let mut indexed: Vec<(usize, Vec<f32>)> = data
319            .iter()
320            .map(|item| {
321                let index = item.get("index").and_then(|v| v.as_u64()).unwrap_or(0) as usize;
322
323                let embedding = item
324                    .get("embedding")
325                    .and_then(|v| v.as_array())
326                    .ok_or_else(|| ContextError::EmbeddingError {
327                        reason: "Missing 'embedding' in OpenAI response item".to_string(),
328                    })?
329                    .iter()
330                    .map(|v| {
331                        v.as_f64()
332                            .map(|f| f as f32)
333                            .ok_or_else(|| ContextError::EmbeddingError {
334                                reason: "Invalid float in embedding".to_string(),
335                            })
336                    })
337                    .collect::<Result<Vec<f32>, _>>()?;
338
339                Ok((index, embedding))
340            })
341            .collect::<Result<Vec<_>, ContextError>>()?;
342
343        indexed.sort_by_key(|(i, _)| *i);
344
345        Ok(indexed.into_iter().map(|(_, emb)| emb).collect())
346    }
347
348    fn embedding_dimension(&self) -> usize {
349        self.dimension
350    }
351
352    fn max_text_length(&self) -> usize {
353        8191 // OpenAI token limit
354    }
355}
356
357/// Create an embedding service from a resolved config.
358pub fn create_embedding_service(
359    config: &EmbeddingConfig,
360) -> Result<Arc<dyn EmbeddingService>, ContextError> {
361    match config.provider {
362        EmbeddingProvider::Ollama => {
363            tracing::info!(
364                model = %config.model,
365                url = %config.base_url,
366                dimension = config.dimension,
367                "Using Ollama embedding service"
368            );
369            Ok(Arc::new(OllamaEmbeddingService::new(config)?))
370        }
371        EmbeddingProvider::OpenAi => {
372            tracing::info!(
373                model = %config.model,
374                url = %config.base_url,
375                dimension = config.dimension,
376                "Using OpenAI embedding service"
377            );
378            Ok(Arc::new(OpenAiEmbeddingService::new(config)?))
379        }
380    }
381}
382
383/// Create an embedding service from environment variables, falling back to
384/// `MockEmbeddingService` when no provider is configured.
385pub fn create_embedding_service_from_env(
386    fallback_dimension: usize,
387) -> Result<Arc<dyn EmbeddingService>, ContextError> {
388    match EmbeddingConfig::from_env() {
389        Some(config) => create_embedding_service(&config),
390        None => {
391            tracing::debug!(
392                dimension = fallback_dimension,
393                "No embedding provider configured, using mock embedding service"
394            );
395            Ok(Arc::new(MockEmbeddingService::new(fallback_dimension)))
396        }
397    }
398}
399
#[cfg(test)]
mod tests {
    use super::*;
    use serial_test::serial;

    /// Helper: clear all embedding-related env vars before each test
    ///
    /// Env vars are process-global, which is why every test below also runs
    /// under `#[serial]` — parallel tests would race on them.
    fn clear_env() {
        for var in &[
            "EMBEDDING_PROVIDER",
            "EMBEDDING_API_KEY",
            "OPENAI_API_KEY",
            "EMBEDDING_API_BASE_URL",
            "OPENAI_API_BASE_URL",
            "EMBEDDING_MODEL",
            "VECTOR_DIMENSION",
        ] {
            std::env::remove_var(var);
        }
    }

    // Explicit provider=ollama with nothing else set → all Ollama defaults.
    #[test]
    #[serial]
    fn test_embedding_config_defaults_ollama() {
        clear_env();
        std::env::set_var("EMBEDDING_PROVIDER", "ollama");

        let config = EmbeddingConfig::from_env().expect("should resolve");
        assert_eq!(config.provider, EmbeddingProvider::Ollama);
        assert_eq!(config.model, "nomic-embed-text");
        assert_eq!(config.base_url, "http://localhost:11434");
        assert_eq!(config.dimension, 768);
        assert!(config.api_key.is_none());
    }

    // Explicit provider=openai with a key → all OpenAI defaults + key carried.
    #[test]
    #[serial]
    fn test_embedding_config_defaults_openai() {
        clear_env();
        std::env::set_var("EMBEDDING_PROVIDER", "openai");
        std::env::set_var("OPENAI_API_KEY", "sk-test");

        let config = EmbeddingConfig::from_env().expect("should resolve");
        assert_eq!(config.provider, EmbeddingProvider::OpenAi);
        assert_eq!(config.model, "text-embedding-3-small");
        assert_eq!(config.base_url, "https://api.openai.com/v1");
        assert_eq!(config.dimension, 1536);
        assert_eq!(config.api_key.as_deref(), Some("sk-test"));
    }

    // No explicit provider: presence of OPENAI_API_KEY alone selects OpenAI.
    #[test]
    #[serial]
    fn test_embedding_config_auto_detect_openai_from_key() {
        clear_env();
        std::env::set_var("OPENAI_API_KEY", "sk-auto");

        let config = EmbeddingConfig::from_env().expect("should resolve");
        assert_eq!(config.provider, EmbeddingProvider::OpenAi);
        assert_eq!(config.api_key.as_deref(), Some("sk-auto"));
    }

    // No explicit provider: a localhost base URL selects Ollama.
    #[test]
    #[serial]
    fn test_embedding_config_auto_detect_ollama_from_localhost_url() {
        clear_env();
        std::env::set_var("EMBEDDING_API_BASE_URL", "http://localhost:11434");

        let config = EmbeddingConfig::from_env().expect("should resolve");
        assert_eq!(config.provider, EmbeddingProvider::Ollama);
    }

    // With nothing set, from_env signals "use the mock" by returning None.
    #[test]
    #[serial]
    fn test_embedding_config_none_when_no_provider() {
        clear_env();
        assert!(EmbeddingConfig::from_env().is_none());
    }

    // VECTOR_DIMENSION overrides the provider's default dimension.
    #[test]
    #[serial]
    fn test_embedding_config_dimension_override() {
        clear_env();
        std::env::set_var("EMBEDDING_PROVIDER", "ollama");
        std::env::set_var("VECTOR_DIMENSION", "1024");

        let config = EmbeddingConfig::from_env().expect("should resolve");
        assert_eq!(config.dimension, 1024);
    }

    // Factory falls back to the mock service (with the requested dimension)
    // when no env vars are set.
    #[test]
    #[serial]
    fn test_create_embedding_service_from_env_fallback() {
        clear_env();

        let svc = create_embedding_service_from_env(256).expect("should return mock");
        assert_eq!(svc.embedding_dimension(), 256);
    }

    // The mock fallback must actually produce embeddings of the right size.
    #[tokio::test]
    #[serial]
    async fn test_mock_fallback_generates_embeddings() {
        clear_env();

        let svc = create_embedding_service_from_env(128).expect("should return mock");
        let emb = svc.generate_embedding("hello world").await.unwrap();
        assert_eq!(emb.len(), 128);

        // Verify it's normalized (magnitude ≈ 1.0)
        let mag: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((mag - 1.0).abs() < 0.01);
    }
}