1use anyhow::{anyhow, Result};
12use async_trait::async_trait;
13use reqwest::Client;
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16use std::sync::Arc;
17use std::time::{Duration, Instant};
18use tokio::sync::RwLock;
19
/// Model name used by `EmbeddingsConfig::openai_from_env` when
/// `OPENAI_EMBEDDING_MODEL` is not set.
pub const DEFAULT_EMBEDDING_MODEL: &str = "text-embedding-ada-002";

/// Vector length used as a fallback when no other dimension is known.
pub const DEFAULT_EMBEDDING_DIMENSION: usize = 1536;

/// Time-to-live, in seconds, of entries in the cache built by
/// `EmbeddingCache::default_cache`.
pub const DEFAULT_CACHE_TTL_SECS: u64 = 3600;

/// Capacity of the cache built by `EmbeddingCache::default_cache`.
pub const DEFAULT_CACHE_MAX_ENTRIES: usize = 10000;
35
/// Common interface implemented by every embedding backend in this file.
#[async_trait]
pub trait EmbeddingModel: Send + Sync {
    /// Embeds a single text and returns its vector.
    async fn embed(&self, text: &str) -> Result<Vec<f32>>;

    /// Embeds several texts in one call.
    ///
    /// Callers in this file pair the result positionally with `texts`, so the
    /// returned vectors are expected to be in input order.
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>>;

    /// Length of the vectors this model is expected to produce.
    fn dimension(&self) -> usize;

    /// Name of the underlying model.
    fn model_name(&self) -> &str;

    /// Short provider identifier such as `"openai"` or `"mock"`.
    fn provider(&self) -> &str;
}
60
/// One cached embedding together with the bookkeeping used for expiry and eviction.
#[derive(Debug, Clone)]
struct CacheEntry {
    /// The cached vector.
    embedding: Vec<f32>,
    /// When the entry was inserted; compared against the cache TTL on lookup.
    created_at: Instant,
    /// Number of successful lookups since insertion; the lowest count is evicted first.
    access_count: usize,
}
75
/// In-memory embedding cache keyed by provider, model and text.
#[derive(Debug)]
pub struct EmbeddingCache {
    /// Cached entries behind an async read/write lock.
    store: RwLock<HashMap<String, CacheEntry>>,
    /// Entry count at which `put` starts evicting.
    max_entries: usize,
    /// Entry lifetime in seconds.
    ttl_secs: u64,
}
88
89impl EmbeddingCache {
90 pub fn new(max_entries: usize, ttl_secs: u64) -> Self {
92 Self {
93 store: RwLock::new(HashMap::new()),
94 max_entries,
95 ttl_secs,
96 }
97 }
98
99 pub fn default_cache() -> Self {
101 Self::new(DEFAULT_CACHE_MAX_ENTRIES, DEFAULT_CACHE_TTL_SECS)
102 }
103
104 fn cache_key(provider: &str, model: &str, text: &str) -> String {
106 use std::collections::hash_map::DefaultHasher;
107 use std::hash::{Hash, Hasher};
108
109 let mut hasher = DefaultHasher::new();
110 provider.hash(&mut hasher);
111 model.hash(&mut hasher);
112 text.hash(&mut hasher);
113 format!("{}:{}:{:016x}", provider, model, hasher.finish())
114 }
115
116 pub async fn get(&self, provider: &str, model: &str, text: &str) -> Option<Vec<f32>> {
118 let key = Self::cache_key(provider, model, text);
119 let mut store = self.store.write().await;
120
121 if let Some(entry) = store.get_mut(&key) {
122 if entry.created_at.elapsed() > Duration::from_secs(self.ttl_secs) {
124 store.remove(&key);
125 return None;
126 }
127
128 entry.access_count += 1;
129 return Some(entry.embedding.clone());
130 }
131
132 None
133 }
134
135 pub async fn put(&self, provider: &str, model: &str, text: &str, embedding: Vec<f32>) {
137 let key = Self::cache_key(provider, model, text);
138 let mut store = self.store.write().await;
139
140 if store.len() >= self.max_entries {
142 if let Some((lru_key, _)) = store
143 .iter()
144 .min_by_key(|(_, e)| e.access_count)
145 .map(|(k, v)| (k.clone(), v.access_count))
146 {
147 store.remove(&lru_key);
148 }
149 }
150
151 store.insert(
152 key,
153 CacheEntry {
154 embedding,
155 created_at: Instant::now(),
156 access_count: 0,
157 },
158 );
159 }
160
161 pub async fn get_batch(
163 &self,
164 provider: &str,
165 model: &str,
166 texts: &[String],
167 ) -> Vec<Option<Vec<f32>>> {
168 let mut results = Vec::with_capacity(texts.len());
169 let mut store = self.store.write().await;
170
171 for text in texts {
172 let key = Self::cache_key(provider, model, text);
173
174 if let Some(entry) = store.get_mut(&key) {
175 if entry.created_at.elapsed() > Duration::from_secs(self.ttl_secs) {
176 store.remove(&key);
177 results.push(None);
178 } else {
179 entry.access_count += 1;
180 results.push(Some(entry.embedding.clone()));
181 }
182 } else {
183 results.push(None);
184 }
185 }
186
187 results
188 }
189
190 pub async fn clear(&self) {
192 let mut store = self.store.write().await;
193 store.clear();
194 }
195
196 pub async fn stats(&self) -> CacheStats {
198 let store = self.store.read().await;
199 let total_entries = store.len();
200 let total_access: usize = store.values().map(|e| e.access_count).sum();
201
202 CacheStats {
203 total_entries,
204 total_access,
205 max_entries: self.max_entries,
206 ttl_secs: self.ttl_secs,
207 }
208 }
209}
210
/// Snapshot returned by `EmbeddingCache::stats`.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Number of entries currently stored.
    pub total_entries: usize,
    /// Sum of the access counters of all stored entries.
    pub total_access: usize,
    /// Configured capacity of the cache.
    pub max_entries: usize,
    /// Configured entry lifetime in seconds.
    pub ttl_secs: u64,
}
219
/// Embedding backends selectable through `EmbeddingsConfig`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EmbeddingProvider {
    OpenAI,
    HuggingFace,
    Cohere,
    Local,
    Mock,
}

impl EmbeddingProvider {
    /// Returns the lowercase identifier of this provider.
    pub fn as_str(&self) -> &'static str {
        match *self {
            EmbeddingProvider::OpenAI => "openai",
            EmbeddingProvider::HuggingFace => "huggingface",
            EmbeddingProvider::Cohere => "cohere",
            EmbeddingProvider::Local => "local",
            EmbeddingProvider::Mock => "mock",
        }
    }
}
246
247#[derive(Debug, Clone)]
249pub struct EmbeddingsConfig {
250 pub provider: EmbeddingProvider,
252 pub api_key: String,
254 pub base_url: Option<String>,
256 pub model: String,
258 pub dimension: Option<usize>,
260}
261
262impl Default for EmbeddingsConfig {
263 fn default() -> Self {
264 Self {
267 provider: EmbeddingProvider::Mock,
268 api_key: String::new(),
269 base_url: None,
270 model: "mock-embedding".to_string(),
271 dimension: Some(DEFAULT_EMBEDDING_DIMENSION),
272 }
273 }
274}
275
276impl EmbeddingsConfig {
277 pub fn openai_from_env() -> Result<Self> {
279 let api_key = std::env::var("OPENAI_API_KEY")
280 .map_err(|_| anyhow!("OPENAI_API_KEY environment variable not set"))?;
281
282 let base_url = std::env::var("OPENAI_BASE_URL")
283 .ok()
284 .or_else(|| Some("https://api.openai.com/v1".to_string()));
285
286 let model = std::env::var("OPENAI_EMBEDDING_MODEL")
287 .unwrap_or_else(|_| DEFAULT_EMBEDDING_MODEL.to_string());
288
289 Ok(Self {
290 provider: EmbeddingProvider::OpenAI,
291 api_key,
292 base_url,
293 model,
294 dimension: None,
295 })
296 }
297
298 pub fn huggingface_from_env() -> Result<Self> {
300 let api_key = std::env::var("HUGGINGFACE_API_KEY")
301 .map_err(|_| anyhow!("HUGGINGFACE_API_KEY environment variable not set"))?;
302
303 let model = std::env::var("HUGGINGFACE_EMBEDDING_MODEL")
304 .unwrap_or_else(|_| "sentence-transformers/all-MiniLM-L6-v2".to_string());
305
306 Ok(Self {
307 provider: EmbeddingProvider::HuggingFace,
308 api_key,
309 base_url: Some(
310 "https://api-inference.huggingface.co/pipeline/feature-extraction".to_string(),
311 ),
312 model,
313 dimension: None,
314 })
315 }
316
317 pub fn cohere_from_env() -> Result<Self> {
319 let api_key = std::env::var("COHERE_API_KEY")
320 .map_err(|_| anyhow!("COHERE_API_KEY environment variable not set"))?;
321
322 let model = std::env::var("COHERE_EMBEDDING_MODEL")
323 .unwrap_or_else(|_| "embed-english-v3.0".to_string());
324
325 Ok(Self {
326 provider: EmbeddingProvider::Cohere,
327 api_key,
328 base_url: Some("https://api.cohere.ai/v1".to_string()),
329 model,
330 dimension: None,
331 })
332 }
333
334 pub fn local(model: impl Into<String>, dimension: Option<usize>) -> Self {
336 Self {
337 provider: EmbeddingProvider::Local,
338 api_key: String::new(),
339 base_url: None,
340 model: model.into(),
341 dimension,
342 }
343 }
344
345 pub fn is_valid(&self) -> bool {
347 matches!(
348 self.provider,
349 EmbeddingProvider::Local | EmbeddingProvider::Mock
350 ) || !self.api_key.is_empty()
351 }
352}
353
354#[derive(Debug)]
360pub struct OpenAIEmbeddings {
361 client: Client,
362 config: EmbeddingsConfig,
363 cache: Option<Arc<EmbeddingCache>>,
364}
365
366impl OpenAIEmbeddings {
367 pub fn new(config: EmbeddingsConfig) -> Result<Self> {
368 if !config.is_valid() {
369 return Err(anyhow!("OpenAI Embeddings API not configured"));
370 }
371
372 Ok(Self {
373 client: Client::new(),
374 config,
375 cache: None,
376 })
377 }
378
379 pub fn with_cache(config: EmbeddingsConfig, cache: Arc<EmbeddingCache>) -> Result<Self> {
380 let mut embeddings = Self::new(config)?;
381 embeddings.cache = Some(cache);
382 Ok(embeddings)
383 }
384
385 fn base_url(&self) -> &str {
386 self.config
387 .base_url
388 .as_deref()
389 .unwrap_or("https://api.openai.com/v1")
390 }
391}
392
393#[async_trait]
394impl EmbeddingModel for OpenAIEmbeddings {
395 async fn embed(&self, text: &str) -> Result<Vec<f32>> {
396 let embeddings = self.embed_batch(&[text.to_string()]).await?;
397 embeddings
398 .into_iter()
399 .next()
400 .ok_or_else(|| anyhow!("No embedding returned"))
401 }
402
403 async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
404 if texts.is_empty() {
405 return Ok(Vec::new());
406 }
407
408 if let Some(cache) = &self.cache {
410 let cached = cache.get_batch("openai", &self.config.model, texts).await;
411 let all_cached = cached.iter().all(|c| c.is_some());
412 if all_cached {
413 return Ok(cached.into_iter().map(|c| c.unwrap()).collect());
414 }
415 }
416
417 let url = format!("{}/embeddings", self.base_url());
418
419 let request_body = OpenAiEmbeddingRequest {
420 model: self.config.model.clone(),
421 input: texts.to_vec(),
422 encoding_format: Some("float".to_string()),
423 };
424
425 tracing::debug!("Sending OpenAI embedding request for {} texts", texts.len());
426
427 let response = self
428 .client
429 .post(&url)
430 .header("Authorization", format!("Bearer {}", self.config.api_key))
431 .header("Content-Type", "application/json")
432 .json(&request_body)
433 .send()
434 .await?;
435
436 let status = response.status();
437 let response_text = response.text().await?;
438
439 if !status.is_success() {
440 tracing::error!("OpenAI Embedding API error: {} - {}", status, response_text);
441 return Err(anyhow!(
442 "OpenAI Embedding API request failed with status {}: {}",
443 status,
444 response_text
445 ));
446 }
447
448 let response_body: OpenAiEmbeddingResponse =
449 serde_json::from_str(&response_text).map_err(|e| {
450 anyhow!(
451 "Failed to parse OpenAI embedding response: {} - {}",
452 e,
453 response_text
454 )
455 })?;
456
457 let mut embeddings: Vec<(usize, Vec<f32>)> = response_body
459 .data
460 .into_iter()
461 .map(|item| (item.index, item.embedding))
462 .collect();
463 embeddings.sort_by_key(|(idx, _)| *idx);
464 let result: Vec<Vec<f32>> = embeddings.into_iter().map(|(_, emb)| emb).collect();
465
466 if let Some(cache) = &self.cache {
468 for (text, embedding) in texts.iter().zip(result.iter()) {
469 cache
470 .put("openai", &self.config.model, text, embedding.clone())
471 .await;
472 }
473 }
474
475 Ok(result)
476 }
477
478 fn dimension(&self) -> usize {
479 match self.config.model.as_str() {
480 "text-embedding-ada-002" => 1536,
481 "text-embedding-3-small" => 1536,
482 "text-embedding-3-large" => 3072,
483 _ => DEFAULT_EMBEDDING_DIMENSION,
484 }
485 }
486
487 fn model_name(&self) -> &str {
488 &self.config.model
489 }
490
491 fn provider(&self) -> &str {
492 "openai"
493 }
494}
495
/// JSON body sent to the OpenAI embeddings endpoint.
#[derive(Serialize)]
struct OpenAiEmbeddingRequest {
    /// Model name.
    model: String,
    /// Texts to embed.
    input: Vec<String>,
    /// Requested encoding; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    encoding_format: Option<String>,
}
503
/// JSON body returned by the OpenAI embeddings endpoint.
#[derive(Deserialize)]
struct OpenAiEmbeddingResponse {
    /// One item per embedded input.
    data: Vec<OpenAiEmbeddingData>,
    // Parsed but not read by this file.
    #[allow(dead_code)]
    model: String,
    // Parsed but not read by this file.
    #[allow(dead_code)]
    usage: OpenAiEmbeddingUsage,
}
512
/// One embedding item inside an OpenAI response.
#[derive(Deserialize)]
struct OpenAiEmbeddingData {
    /// The embedding vector.
    embedding: Vec<f32>,
    /// Index reported by the API; used to sort items before returning them.
    index: usize,
    // Parsed but not read by this file.
    #[allow(dead_code)]
    object: String,
}
520
/// Token usage block of an OpenAI response; parsed but not read by this file.
#[derive(Deserialize)]
#[allow(dead_code)]
struct OpenAiEmbeddingUsage {
    prompt_tokens: u32,
    total_tokens: u32,
}
527
528#[derive(Debug)]
534pub struct HuggingFaceEmbeddings {
535 client: Client,
536 config: EmbeddingsConfig,
537 cache: Option<Arc<EmbeddingCache>>,
538}
539
540impl HuggingFaceEmbeddings {
541 pub fn new(config: EmbeddingsConfig) -> Result<Self> {
542 if !config.is_valid() {
543 return Err(anyhow!("HuggingFace API not configured"));
544 }
545
546 Ok(Self {
547 client: Client::new(),
548 config,
549 cache: None,
550 })
551 }
552
553 pub fn with_cache(config: EmbeddingsConfig, cache: Arc<EmbeddingCache>) -> Result<Self> {
554 let mut embeddings = Self::new(config)?;
555 embeddings.cache = Some(cache);
556 Ok(embeddings)
557 }
558}
559
560#[async_trait]
561impl EmbeddingModel for HuggingFaceEmbeddings {
562 async fn embed(&self, text: &str) -> Result<Vec<f32>> {
563 let embeddings = self.embed_batch(&[text.to_string()]).await?;
565 embeddings
566 .into_iter()
567 .next()
568 .ok_or_else(|| anyhow!("No embedding returned from HuggingFace"))
569 }
570
571 async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
572 if texts.is_empty() {
573 return Ok(Vec::new());
574 }
575
576 if let Some(cache) = &self.cache {
578 let cached = cache
579 .get_batch("huggingface", &self.config.model, texts)
580 .await;
581 let all_cached = cached.iter().all(|c| c.is_some());
582 if all_cached {
583 return Ok(cached.into_iter().map(|c| c.unwrap()).collect());
584 }
585 }
586
587 let url = format!(
588 "https://api-inference.huggingface.co/pipeline/feature-extraction/{}",
589 self.config.model
590 );
591
592 tracing::debug!(
593 "Sending HuggingFace embedding request for {} texts",
594 texts.len()
595 );
596
597 let response = self
598 .client
599 .post(&url)
600 .header("Authorization", format!("Bearer {}", self.config.api_key))
601 .header("Content-Type", "application/json")
602 .json(&serde_json::json!({ "inputs": texts }))
603 .send()
604 .await?;
605
606 let status = response.status();
607 let response_text = response.text().await?;
608
609 if !status.is_success() {
610 tracing::error!("HuggingFace API error: {} - {}", status, response_text);
611 return Err(anyhow!(
612 "HuggingFace API request failed with status {}: {}",
613 status,
614 response_text
615 ));
616 }
617
618 let embeddings: Vec<Vec<f32>> = serde_json::from_str(&response_text).map_err(|e| {
620 anyhow!(
621 "Failed to parse HuggingFace response: {} - {}",
622 e,
623 response_text
624 )
625 })?;
626
627 if let Some(cache) = &self.cache {
629 for (text, embedding) in texts.iter().zip(embeddings.iter()) {
630 cache
631 .put("huggingface", &self.config.model, text, embedding.clone())
632 .await;
633 }
634 }
635
636 Ok(embeddings)
637 }
638
639 fn dimension(&self) -> usize {
640 match self.config.model.as_str() {
642 "sentence-transformers/all-MiniLM-L6-v2" => 384,
643 "sentence-transformers/all-mpnet-base-v2" => 768,
644 "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" => 384,
645 _ => self.config.dimension.unwrap_or(768),
646 }
647 }
648
649 fn model_name(&self) -> &str {
650 &self.config.model
651 }
652
653 fn provider(&self) -> &str {
654 "huggingface"
655 }
656}
657
658#[derive(Debug)]
664pub struct CohereEmbeddings {
665 client: Client,
666 config: EmbeddingsConfig,
667 cache: Option<Arc<EmbeddingCache>>,
668}
669
670impl CohereEmbeddings {
671 pub fn new(config: EmbeddingsConfig) -> Result<Self> {
672 if !config.is_valid() {
673 return Err(anyhow!("Cohere API not configured"));
674 }
675
676 Ok(Self {
677 client: Client::new(),
678 config,
679 cache: None,
680 })
681 }
682
683 pub fn with_cache(config: EmbeddingsConfig, cache: Arc<EmbeddingCache>) -> Result<Self> {
684 let mut embeddings = Self::new(config)?;
685 embeddings.cache = Some(cache);
686 Ok(embeddings)
687 }
688}
689
690#[async_trait]
691impl EmbeddingModel for CohereEmbeddings {
692 async fn embed(&self, text: &str) -> Result<Vec<f32>> {
693 let embeddings = self.embed_batch(&[text.to_string()]).await?;
694 embeddings
695 .into_iter()
696 .next()
697 .ok_or_else(|| anyhow!("No embedding returned from Cohere"))
698 }
699
700 async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
701 if texts.is_empty() {
702 return Ok(Vec::new());
703 }
704
705 if let Some(cache) = &self.cache {
707 let cached = cache.get_batch("cohere", &self.config.model, texts).await;
708 let all_cached = cached.iter().all(|c| c.is_some());
709 if all_cached {
710 return Ok(cached.into_iter().map(|c| c.unwrap()).collect());
711 }
712 }
713
714 let url = "https://api.cohere.ai/v1/embed";
715
716 let request_body = CohereEmbeddingRequest {
717 model: self.config.model.clone(),
718 texts: texts.to_vec(),
719 input_type: "search_document",
720 embedding_types: Some(vec!["float".to_string()]),
721 };
722
723 tracing::debug!("Sending Cohere embedding request for {} texts", texts.len());
724
725 let response = self
726 .client
727 .post(url)
728 .header("Authorization", format!("Bearer {}", self.config.api_key))
729 .header("Content-Type", "application/json")
730 .json(&request_body)
731 .send()
732 .await?;
733
734 let status = response.status();
735 let response_text = response.text().await?;
736
737 if !status.is_success() {
738 tracing::error!("Cohere API error: {} - {}", status, response_text);
739 return Err(anyhow!(
740 "Cohere API request failed with status {}: {}",
741 status,
742 response_text
743 ));
744 }
745
746 let response_body: CohereEmbeddingResponse = serde_json::from_str(&response_text)
747 .map_err(|e| anyhow!("Failed to parse Cohere response: {} - {}", e, response_text))?;
748
749 let result = response_body.embeddings.float;
750
751 if let Some(cache) = &self.cache {
753 for (text, embedding) in texts.iter().zip(result.iter()) {
754 cache
755 .put("cohere", &self.config.model, text, embedding.clone())
756 .await;
757 }
758 }
759
760 Ok(result)
761 }
762
763 fn dimension(&self) -> usize {
764 match self.config.model.as_str() {
765 "embed-english-v3.0" | "embed-english-light-v3.0" => 1024,
766 "embed-multilingual-v3.0" => 1024,
767 "embed-english-v2.0" => 4096,
768 _ => self.config.dimension.unwrap_or(1024),
769 }
770 }
771
772 fn model_name(&self) -> &str {
773 &self.config.model
774 }
775
776 fn provider(&self) -> &str {
777 "cohere"
778 }
779}
780
/// JSON body sent to the Cohere embed endpoint.
#[derive(Serialize)]
struct CohereEmbeddingRequest {
    /// Model name.
    model: String,
    /// Texts to embed.
    texts: Vec<String>,
    /// Input type string sent with the request.
    input_type: &'static str,
    /// Requested embedding encodings; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    embedding_types: Option<Vec<String>>,
}
789
790#[derive(Deserialize)]
791struct CohereEmbeddingResponse {
792 embeddings: CohereEmbeddingsData,
793 #[allow(dead_code)]
794 id: String,
795 #[allow(dead_code)]
796 text_type: String,
797}
798
/// The `embeddings` object of a Cohere response.
#[derive(Deserialize)]
struct CohereEmbeddingsData {
    /// Float vectors, read from the `float` key.
    float: Vec<Vec<f32>>,
}
803
/// Embedding provider that runs without a remote API.
pub struct LocalEmbeddings {
    /// Provider settings (model name, optional dimension).
    config: EmbeddingsConfig,
    /// Optional shared cache.
    cache: Option<Arc<EmbeddingCache>>,
    /// Backend slot; only present when the `local-embeddings` feature is enabled.
    #[cfg(feature = "local-embeddings")]
    #[allow(dead_code)]
    model: Option<std::sync::Mutex<Box<dyn LocalModelBackend>>>,
}
819
820impl std::fmt::Debug for LocalEmbeddings {
821 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
822 f.debug_struct("LocalEmbeddings")
823 .field("config", &self.config)
824 .field("cache", &self.cache)
825 .field("model", &"<model>")
826 .finish()
827 }
828}
829
830impl LocalEmbeddings {
831 pub fn new(config: EmbeddingsConfig) -> Result<Self> {
832 Ok(Self {
833 config,
834 cache: None,
835 #[cfg(feature = "local-embeddings")]
836 model: None,
837 })
838 }
839
840 pub fn with_cache(config: EmbeddingsConfig, cache: Arc<EmbeddingCache>) -> Result<Self> {
841 let mut embeddings = Self::new(config)?;
842 embeddings.cache = Some(cache);
843 Ok(embeddings)
844 }
845
846 #[cfg(feature = "local-embeddings")]
848 pub fn load_model(&mut self) -> Result<()> {
849 tracing::info!("Loading local embedding model: {}", self.config.model);
852 Ok(())
853 }
854}
855
856#[async_trait]
857impl EmbeddingModel for LocalEmbeddings {
858 async fn embed(&self, text: &str) -> Result<Vec<f32>> {
859 if let Some(cache) = &self.cache {
861 if let Some(embedding) = cache.get("local", &self.config.model, text).await {
862 return Ok(embedding);
863 }
864 }
865
866 #[cfg(feature = "local-embeddings")]
867 {
868 let embedding = vec![0.0f32; self.dimension()];
871
872 if let Some(cache) = &self.cache {
873 cache
874 .put("local", &self.config.model, text, embedding.clone())
875 .await;
876 }
877
878 Ok(embedding)
879 }
880
881 #[cfg(not(feature = "local-embeddings"))]
882 {
883 Err(anyhow!(
884 "Local embeddings require 'local-embeddings' feature. \
885 Enable it in Cargo.toml and ensure candle or ort is available."
886 ))
887 }
888 }
889
890 async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
891 if let Some(cache) = &self.cache {
893 let cached = cache.get_batch("local", &self.config.model, texts).await;
894 if cached.iter().all(|c| c.is_some()) {
895 return Ok(cached.into_iter().map(|c| c.unwrap()).collect());
896 }
897 }
898
899 #[cfg(feature = "local-embeddings")]
900 {
901 let mut results = Vec::with_capacity(texts.len());
902 for text in texts {
903 results.push(self.embed(text).await?);
904 }
905
906 if let Some(cache) = &self.cache {
908 for (text, embedding) in texts.iter().zip(results.iter()) {
909 cache
910 .put("local", &self.config.model, text, embedding.clone())
911 .await;
912 }
913 }
914
915 Ok(results)
916 }
917
918 #[cfg(not(feature = "local-embeddings"))]
919 {
920 Err(anyhow!(
921 "Local embeddings require 'local-embeddings' feature"
922 ))
923 }
924 }
925
926 fn dimension(&self) -> usize {
927 self.config.dimension.unwrap_or(384)
928 }
929
930 fn model_name(&self) -> &str {
931 &self.config.model
932 }
933
934 fn provider(&self) -> &str {
935 "local"
936 }
937}
938
/// Backend interface for a locally executed embedding model; only compiled
/// with the `local-embeddings` feature.
#[cfg(feature = "local-embeddings")]
#[allow(dead_code)]
trait LocalModelBackend: Send + Sync {
    /// Encodes `text` into an embedding vector.
    fn encode(&self, text: &str) -> Result<Vec<f32>>;
}
945
/// Builds embedding models that share a single cache.
pub struct EmbeddingsFactory {
    /// Cache handed to the cache-aware providers created by this factory.
    cache: Arc<EmbeddingCache>,
}
954
955impl EmbeddingsFactory {
956 pub fn new() -> Self {
957 Self {
958 cache: Arc::new(EmbeddingCache::default_cache()),
959 }
960 }
961
962 pub fn with_cache(cache: Arc<EmbeddingCache>) -> Self {
963 Self { cache }
964 }
965
966 pub fn create(&self, config: EmbeddingsConfig) -> Result<Box<dyn EmbeddingModel>> {
968 match config.provider {
969 EmbeddingProvider::OpenAI => Ok(Box::new(OpenAIEmbeddings::with_cache(
970 config,
971 self.cache.clone(),
972 )?)),
973 EmbeddingProvider::HuggingFace => Ok(Box::new(HuggingFaceEmbeddings::with_cache(
974 config,
975 self.cache.clone(),
976 )?)),
977 EmbeddingProvider::Cohere => Ok(Box::new(CohereEmbeddings::with_cache(
978 config,
979 self.cache.clone(),
980 )?)),
981 EmbeddingProvider::Local => Ok(Box::new(LocalEmbeddings::with_cache(
982 config,
983 self.cache.clone(),
984 )?)),
985 EmbeddingProvider::Mock => {
986 let dimension = config.dimension.unwrap_or(DEFAULT_EMBEDDING_DIMENSION);
987 #[cfg(any(feature = "mock", test))]
988 {
989 Ok(Box::new(MockEmbeddingModel::with_name(
990 dimension,
991 &config.model,
992 )))
993 }
994 #[cfg(not(any(feature = "mock", test)))]
995 {
996 let local_config = EmbeddingsConfig::local(&config.model, Some(dimension));
998 Ok(Box::new(LocalEmbeddings::new(local_config)?))
999 }
1000 }
1001 }
1002 }
1003
1004 pub fn create_safe(&self, config: EmbeddingsConfig) -> Box<dyn EmbeddingModel> {
1009 if config.is_valid() {
1010 self.create(config)
1011 .unwrap_or_else(|_| self.create_mock_default())
1012 } else {
1013 self.create_mock_default()
1014 }
1015 }
1016
1017 fn create_mock_default(&self) -> Box<dyn EmbeddingModel> {
1019 #[cfg(any(feature = "mock", test))]
1020 {
1021 Box::new(MockEmbeddingModel::new(DEFAULT_EMBEDDING_DIMENSION))
1022 }
1023 #[cfg(not(any(feature = "mock", test)))]
1024 {
1025 let config = EmbeddingsConfig::local("fallback", Some(DEFAULT_EMBEDDING_DIMENSION));
1027 Box::new(LocalEmbeddings::new(config).expect("Local embeddings should always work"))
1028 }
1029 }
1030
1031 pub fn openai(&self) -> Result<Box<dyn EmbeddingModel>> {
1033 let config = EmbeddingsConfig::openai_from_env()?;
1034 self.create(config)
1035 }
1036
1037 pub fn huggingface(&self) -> Result<Box<dyn EmbeddingModel>> {
1039 let config = EmbeddingsConfig::huggingface_from_env()?;
1040 self.create(config)
1041 }
1042
1043 pub fn cohere(&self) -> Result<Box<dyn EmbeddingModel>> {
1045 let config = EmbeddingsConfig::cohere_from_env()?;
1046 self.create(config)
1047 }
1048
1049 pub fn local(&self, model: &str, dimension: Option<usize>) -> Result<Box<dyn EmbeddingModel>> {
1051 let config = EmbeddingsConfig::local(model, dimension);
1052 self.create(config)
1053 }
1054
1055 #[cfg(any(feature = "mock", test))]
1060 pub fn mock(&self, dimension: usize) -> Box<dyn EmbeddingModel> {
1061 Box::new(MockEmbeddingModel::new(dimension))
1062 }
1063
1064 pub fn cache(&self) -> Arc<EmbeddingCache> {
1066 self.cache.clone()
1067 }
1068}
1069
1070impl Default for EmbeddingsFactory {
1071 fn default() -> Self {
1072 Self::new()
1073 }
1074}
1075
/// Embedding model stand-in that performs no external calls; only compiled
/// with the `mock` feature or under test.
#[cfg(any(feature = "mock", test))]
pub struct MockEmbeddingModel {
    /// Length of the vectors returned.
    dimension: usize,
    /// Name reported by `model_name()`.
    model_name: String,
}

#[cfg(any(feature = "mock", test))]
impl MockEmbeddingModel {
    /// Creates a mock named `"mock-embedding"` with the given dimension.
    pub fn new(dimension: usize) -> Self {
        let model_name = String::from("mock-embedding");
        Self {
            dimension,
            model_name,
        }
    }

    /// Creates a mock with the given dimension and model name.
    pub fn with_name(dimension: usize, model_name: impl Into<String>) -> Self {
        let model_name: String = model_name.into();
        Self {
            dimension,
            model_name,
        }
    }
}
1110
#[cfg(any(feature = "mock", test))]
#[async_trait]
impl EmbeddingModel for MockEmbeddingModel {
    /// Returns a zero vector of length `dimension`, ignoring the text.
    async fn embed(&self, _text: &str) -> Result<Vec<f32>> {
        let zeros = vec![0.0; self.dimension];
        Ok(zeros)
    }

    /// Returns one zero vector of length `dimension` per input text.
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        let mut batch = Vec::with_capacity(texts.len());
        for _ in texts {
            batch.push(vec![0.0; self.dimension]);
        }
        Ok(batch)
    }

    fn dimension(&self) -> usize {
        self.dimension
    }

    fn model_name(&self) -> &str {
        self.model_name.as_str()
    }

    fn provider(&self) -> &str {
        "mock"
    }
}
1134
/// Alias for `OpenAIEmbeddings`.
pub type Embeddings = OpenAIEmbeddings;
1143
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes tests that mutate process-wide environment variables.
    ///
    /// The test harness runs tests on several threads, and the environment is
    /// shared by all of them; without this lock one test can remove
    /// `OPENAI_API_KEY` while another is about to read it.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires `ENV_LOCK`, ignoring poisoning so that one failed test does
    /// not cascade into every other environment-dependent test.
    fn env_guard() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    // ---- Cache ----------------------------------------------------------

    #[tokio::test]
    async fn test_cache_basic_operations() {
        let cache = EmbeddingCache::new(100, 3600);

        let embedding = vec![0.1f32, 0.2, 0.3];
        cache
            .put("openai", "test-model", "hello", embedding.clone())
            .await;

        let cached = cache.get("openai", "test-model", "hello").await;
        assert!(cached.is_some());
        assert_eq!(cached.unwrap(), embedding);

        let not_cached = cache.get("openai", "test-model", "not-exists").await;
        assert!(not_cached.is_none());
    }

    #[tokio::test]
    async fn test_cache_batch_operations() {
        let cache = EmbeddingCache::new(100, 3600);

        let texts: Vec<String> = vec!["a".to_string(), "b".to_string(), "c".to_string()];
        let embeddings: Vec<Vec<f32>> = texts.iter().map(|t| vec![t.len() as f32]).collect();

        for (text, emb) in texts.iter().zip(embeddings.iter()) {
            cache.put("test", "model", text, emb.clone()).await;
        }

        let cached = cache.get_batch("test", "model", &texts).await;
        assert!(cached.iter().all(|c| c.is_some()));
    }

    #[tokio::test]
    async fn test_cache_stats() {
        let cache = EmbeddingCache::new(100, 3600);

        cache.put("test", "model", "a", vec![1.0f32]).await;
        cache.put("test", "model", "b", vec![2.0]).await;

        let _ = cache.get("test", "model", "a").await;
        let _ = cache.get("test", "model", "a").await;

        let stats = cache.stats().await;
        assert_eq!(stats.total_entries, 2);
        assert_eq!(stats.total_access, 2);
    }

    // ---- Configuration --------------------------------------------------

    #[test]
    fn test_config_openai_from_env() {
        let _env = env_guard();
        std::env::set_var("OPENAI_API_KEY", "test_key");
        std::env::remove_var("OPENAI_BASE_URL");
        std::env::remove_var("OPENAI_EMBEDDING_MODEL");

        let config = EmbeddingsConfig::openai_from_env().unwrap();
        assert_eq!(config.api_key, "test_key");
        assert_eq!(config.model, DEFAULT_EMBEDDING_MODEL);

        std::env::remove_var("OPENAI_API_KEY");
    }

    #[test]
    fn test_config_huggingface_from_env() {
        let _env = env_guard();
        std::env::set_var("HUGGINGFACE_API_KEY", "hf_test");
        std::env::remove_var("HUGGINGFACE_EMBEDDING_MODEL");

        let config = EmbeddingsConfig::huggingface_from_env().unwrap();
        assert_eq!(config.api_key, "hf_test");
        assert!(config.model.contains("sentence-transformers"));

        std::env::remove_var("HUGGINGFACE_API_KEY");
    }

    #[test]
    fn test_config_cohere_from_env() {
        let _env = env_guard();
        std::env::set_var("COHERE_API_KEY", "cohere_test");
        std::env::remove_var("COHERE_EMBEDDING_MODEL");

        let config = EmbeddingsConfig::cohere_from_env().unwrap();
        assert_eq!(config.api_key, "cohere_test");
        assert!(config.model.starts_with("embed-"));

        std::env::remove_var("COHERE_API_KEY");
    }

    #[test]
    fn test_config_local() {
        let config = EmbeddingsConfig::local("all-MiniLM-L6-v2", Some(384));
        assert_eq!(config.provider, EmbeddingProvider::Local);
        assert!(config.api_key.is_empty());
        assert!(config.is_valid());
    }

    // ---- Dimensions -----------------------------------------------------

    #[test]
    fn test_openai_dimension() {
        let config = EmbeddingsConfig {
            provider: EmbeddingProvider::OpenAI,
            api_key: "test".to_string(),
            base_url: None,
            model: "text-embedding-ada-002".to_string(),
            dimension: None,
        };
        let embeddings = OpenAIEmbeddings::new(config).unwrap();
        assert_eq!(embeddings.dimension(), 1536);

        let config = EmbeddingsConfig {
            provider: EmbeddingProvider::OpenAI,
            api_key: "test".to_string(),
            base_url: None,
            model: "text-embedding-3-large".to_string(),
            dimension: None,
        };
        let embeddings = OpenAIEmbeddings::new(config).unwrap();
        assert_eq!(embeddings.dimension(), 3072);
    }

    #[test]
    fn test_huggingface_dimension() {
        let config = EmbeddingsConfig {
            provider: EmbeddingProvider::HuggingFace,
            api_key: "test".to_string(),
            base_url: None,
            model: "sentence-transformers/all-MiniLM-L6-v2".to_string(),
            dimension: None,
        };
        let embeddings = HuggingFaceEmbeddings::new(config).unwrap();
        assert_eq!(embeddings.dimension(), 384);
    }

    #[test]
    fn test_cohere_dimension() {
        let config = EmbeddingsConfig {
            provider: EmbeddingProvider::Cohere,
            api_key: "test".to_string(),
            base_url: None,
            model: "embed-english-v3.0".to_string(),
            dimension: None,
        };
        let embeddings = CohereEmbeddings::new(config).unwrap();
        assert_eq!(embeddings.dimension(), 1024);
    }

    // ---- Factory --------------------------------------------------------

    #[test]
    fn test_factory_create_openai() {
        let _env = env_guard();
        std::env::set_var("OPENAI_API_KEY", "test_key");

        let factory = EmbeddingsFactory::new();
        let model = factory.openai().unwrap();
        assert_eq!(model.provider(), "openai");

        std::env::remove_var("OPENAI_API_KEY");
    }

    #[test]
    fn test_factory_create_local() {
        let factory = EmbeddingsFactory::new();
        let model = factory.local("test-model", Some(384)).unwrap();
        assert_eq!(model.provider(), "local");
        assert_eq!(model.dimension(), 384);
    }

    #[test]
    fn test_factory_create_mock() {
        let factory = EmbeddingsFactory::new();
        let model = factory.mock(512);
        assert_eq!(model.provider(), "mock");
        assert_eq!(model.dimension(), 512);
    }

    #[test]
    fn test_factory_create_safe_with_invalid_config() {
        let factory = EmbeddingsFactory::new();
        // An OpenAI config without a key is invalid and must fall back.
        let config = EmbeddingsConfig {
            provider: EmbeddingProvider::OpenAI,
            api_key: String::new(),
            base_url: None,
            model: "test".to_string(),
            dimension: None,
        };
        let model = factory.create_safe(config);
        assert_eq!(model.provider(), "mock");
    }

    #[test]
    fn test_factory_create_safe_with_valid_config() {
        let _env = env_guard();
        std::env::set_var("OPENAI_API_KEY", "test_key");
        let factory = EmbeddingsFactory::new();
        let config = EmbeddingsConfig::openai_from_env().unwrap();
        let model = factory.create_safe(config);
        assert_eq!(model.provider(), "openai");
        std::env::remove_var("OPENAI_API_KEY");
    }

    // ---- Defaults -------------------------------------------------------

    #[test]
    fn test_config_default_is_safe() {
        let config = EmbeddingsConfig::default();
        assert_eq!(config.provider, EmbeddingProvider::Mock);
        assert!(config.is_valid());
    }

    #[test]
    fn test_provider_mock_is_valid() {
        let config = EmbeddingsConfig {
            provider: EmbeddingProvider::Mock,
            api_key: String::new(),
            base_url: None,
            model: "mock-test".to_string(),
            dimension: Some(256),
        };
        assert!(config.is_valid());
    }

    #[test]
    fn test_embeddings_factory_mock_default_dimension() {
        let factory = EmbeddingsFactory::new();
        let model = factory.mock(DEFAULT_EMBEDDING_DIMENSION);
        assert_eq!(model.dimension(), DEFAULT_EMBEDDING_DIMENSION);
    }

    // ---- Backward compatibility -----------------------------------------

    #[test]
    fn test_backward_compatible_embeddings() {
        let _env = env_guard();
        std::env::set_var("OPENAI_API_KEY", "test_key");

        let config = EmbeddingsConfig::openai_from_env().unwrap();
        let embeddings = Embeddings::new(config).unwrap();
        assert_eq!(embeddings.provider(), "openai");

        std::env::remove_var("OPENAI_API_KEY");
    }
}