Skip to main content

sh_layer1/
embeddings.rs

1//! 嵌入模型模块
2//!
3//! 文本嵌入、批量处理、缓存。
4//!
5//! 支持多种嵌入模型提供商:
6//! - OpenAI Embeddings API
7//! - HuggingFace Inference API
8//! - Cohere Embed API
9//! - 本地 SentenceTransformers 模型
10
11use anyhow::{anyhow, Result};
12use async_trait::async_trait;
13use reqwest::Client;
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16use std::sync::Arc;
17use std::time::{Duration, Instant};
18use tokio::sync::RwLock;
19
20// ============================================================================
21// 常量定义
22// ============================================================================
23
/// Default embedding model name.
pub const DEFAULT_EMBEDDING_MODEL: &str = "text-embedding-ada-002";

/// Default embedding vector dimension.
pub const DEFAULT_EMBEDDING_DIMENSION: usize = 1536;

/// Default cache time-to-live, in seconds.
pub const DEFAULT_CACHE_TTL_SECS: u64 = 3600;

/// Default maximum number of cache entries.
pub const DEFAULT_CACHE_MAX_ENTRIES: usize = 10000;
35
36// ============================================================================
37// 统一 EmbeddingModel Trait
38// ============================================================================
39
/// Common interface for embedding models.
///
/// Every embedding backend must implement this trait.
#[async_trait]
pub trait EmbeddingModel: Send + Sync {
    /// Produces the embedding vector for a single text.
    async fn embed(&self, text: &str) -> Result<Vec<f32>>;

    /// Produces embedding vectors for a batch of texts.
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>>;

    /// Returns the vector dimension.
    fn dimension(&self) -> usize;

    /// Returns the model name.
    fn model_name(&self) -> &str;

    /// Returns the provider name.
    fn provider(&self) -> &str;
}
60
61// ============================================================================
62// 嵌入缓存
63// ============================================================================
64
/// A single cached embedding together with its bookkeeping data.
#[derive(Debug, Clone)]
struct CacheEntry {
    /// The embedding vector.
    embedding: Vec<f32>,
    /// When the entry was created; compared against the cache TTL on lookup.
    created_at: Instant,
    /// Number of cache hits served from this entry.
    access_count: usize,
}
75
/// Embedding cache.
///
/// Entries expire after `ttl_secs` seconds. When the cache is full, `put` evicts the
/// entry with the lowest access count (a least-frequently-used policy, not strict LRU).
#[derive(Debug)]
pub struct EmbeddingCache {
    /// Backing store, keyed by the string produced by `cache_key`.
    store: RwLock<HashMap<String, CacheEntry>>,
    /// Maximum number of entries.
    max_entries: usize,
    /// Time-to-live, in seconds.
    ttl_secs: u64,
}
88
89impl EmbeddingCache {
90    /// 创建新的缓存实例
91    pub fn new(max_entries: usize, ttl_secs: u64) -> Self {
92        Self {
93            store: RwLock::new(HashMap::new()),
94            max_entries,
95            ttl_secs,
96        }
97    }
98
99    /// 使用默认配置创建缓存
100    pub fn default_cache() -> Self {
101        Self::new(DEFAULT_CACHE_MAX_ENTRIES, DEFAULT_CACHE_TTL_SECS)
102    }
103
104    /// 生成缓存键
105    fn cache_key(provider: &str, model: &str, text: &str) -> String {
106        use std::collections::hash_map::DefaultHasher;
107        use std::hash::{Hash, Hasher};
108
109        let mut hasher = DefaultHasher::new();
110        provider.hash(&mut hasher);
111        model.hash(&mut hasher);
112        text.hash(&mut hasher);
113        format!("{}:{}:{:016x}", provider, model, hasher.finish())
114    }
115
116    /// 获取缓存的嵌入向量
117    pub async fn get(&self, provider: &str, model: &str, text: &str) -> Option<Vec<f32>> {
118        let key = Self::cache_key(provider, model, text);
119        let mut store = self.store.write().await;
120
121        if let Some(entry) = store.get_mut(&key) {
122            // 检查是否过期
123            if entry.created_at.elapsed() > Duration::from_secs(self.ttl_secs) {
124                store.remove(&key);
125                return None;
126            }
127
128            entry.access_count += 1;
129            return Some(entry.embedding.clone());
130        }
131
132        None
133    }
134
135    /// 存储嵌入向量到缓存
136    pub async fn put(&self, provider: &str, model: &str, text: &str, embedding: Vec<f32>) {
137        let key = Self::cache_key(provider, model, text);
138        let mut store = self.store.write().await;
139
140        // 如果达到最大条目数,移除最少访问的条目
141        if store.len() >= self.max_entries {
142            if let Some((lru_key, _)) = store
143                .iter()
144                .min_by_key(|(_, e)| e.access_count)
145                .map(|(k, v)| (k.clone(), v.access_count))
146            {
147                store.remove(&lru_key);
148            }
149        }
150
151        store.insert(
152            key,
153            CacheEntry {
154                embedding,
155                created_at: Instant::now(),
156                access_count: 0,
157            },
158        );
159    }
160
161    /// 批量获取缓存的嵌入向量
162    pub async fn get_batch(
163        &self,
164        provider: &str,
165        model: &str,
166        texts: &[String],
167    ) -> Vec<Option<Vec<f32>>> {
168        let mut results = Vec::with_capacity(texts.len());
169        let mut store = self.store.write().await;
170
171        for text in texts {
172            let key = Self::cache_key(provider, model, text);
173
174            if let Some(entry) = store.get_mut(&key) {
175                if entry.created_at.elapsed() > Duration::from_secs(self.ttl_secs) {
176                    store.remove(&key);
177                    results.push(None);
178                } else {
179                    entry.access_count += 1;
180                    results.push(Some(entry.embedding.clone()));
181                }
182            } else {
183                results.push(None);
184            }
185        }
186
187        results
188    }
189
190    /// 清空缓存
191    pub async fn clear(&self) {
192        let mut store = self.store.write().await;
193        store.clear();
194    }
195
196    /// 获取缓存统计信息
197    pub async fn stats(&self) -> CacheStats {
198        let store = self.store.read().await;
199        let total_entries = store.len();
200        let total_access: usize = store.values().map(|e| e.access_count).sum();
201
202        CacheStats {
203            total_entries,
204            total_access,
205            max_entries: self.max_entries,
206            ttl_secs: self.ttl_secs,
207        }
208    }
209}
210
/// Cache statistics snapshot.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Number of entries currently stored.
    pub total_entries: usize,
    /// Sum of the access counts of all stored entries.
    pub total_access: usize,
    /// Configured maximum number of entries.
    pub max_entries: usize,
    /// Configured time-to-live, in seconds.
    pub ttl_secs: u64,
}
219
220// ============================================================================
221// 模型配置
222// ============================================================================
223
/// The kinds of embedding provider this module knows about.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EmbeddingProvider {
    OpenAI,
    HuggingFace,
    Cohere,
    Local,
    /// Mock provider, used for tests and as a safe default.
    Mock,
}

impl EmbeddingProvider {
    /// Returns the provider's lowercase identifier.
    pub fn as_str(&self) -> &'static str {
        match *self {
            EmbeddingProvider::Mock => "mock",
            EmbeddingProvider::Local => "local",
            EmbeddingProvider::Cohere => "cohere",
            EmbeddingProvider::HuggingFace => "huggingface",
            EmbeddingProvider::OpenAI => "openai",
        }
    }
}
246
247/// 嵌入模型配置
248#[derive(Debug, Clone)]
249pub struct EmbeddingsConfig {
250    /// 提供商类型
251    pub provider: EmbeddingProvider,
252    /// API 密钥(本地模型可为空)
253    pub api_key: String,
254    /// API 基础 URL(可选)
255    pub base_url: Option<String>,
256    /// 模型名称
257    pub model: String,
258    /// 向量维度(可选,用于本地模型)
259    pub dimension: Option<usize>,
260}
261
262impl Default for EmbeddingsConfig {
263    fn default() -> Self {
264        // 安全默认值:使用 Mock 提供商
265        // 这样可以避免在未配置环境下意外调用外部 API
266        Self {
267            provider: EmbeddingProvider::Mock,
268            api_key: String::new(),
269            base_url: None,
270            model: "mock-embedding".to_string(),
271            dimension: Some(DEFAULT_EMBEDDING_DIMENSION),
272        }
273    }
274}
275
276impl EmbeddingsConfig {
277    /// 从环境变量创建 OpenAI 配置
278    pub fn openai_from_env() -> Result<Self> {
279        let api_key = std::env::var("OPENAI_API_KEY")
280            .map_err(|_| anyhow!("OPENAI_API_KEY environment variable not set"))?;
281
282        let base_url = std::env::var("OPENAI_BASE_URL")
283            .ok()
284            .or_else(|| Some("https://api.openai.com/v1".to_string()));
285
286        let model = std::env::var("OPENAI_EMBEDDING_MODEL")
287            .unwrap_or_else(|_| DEFAULT_EMBEDDING_MODEL.to_string());
288
289        Ok(Self {
290            provider: EmbeddingProvider::OpenAI,
291            api_key,
292            base_url,
293            model,
294            dimension: None,
295        })
296    }
297
298    /// 从环境变量创建 HuggingFace 配置
299    pub fn huggingface_from_env() -> Result<Self> {
300        let api_key = std::env::var("HUGGINGFACE_API_KEY")
301            .map_err(|_| anyhow!("HUGGINGFACE_API_KEY environment variable not set"))?;
302
303        let model = std::env::var("HUGGINGFACE_EMBEDDING_MODEL")
304            .unwrap_or_else(|_| "sentence-transformers/all-MiniLM-L6-v2".to_string());
305
306        Ok(Self {
307            provider: EmbeddingProvider::HuggingFace,
308            api_key,
309            base_url: Some(
310                "https://api-inference.huggingface.co/pipeline/feature-extraction".to_string(),
311            ),
312            model,
313            dimension: None,
314        })
315    }
316
317    /// 从环境变量创建 Cohere 配置
318    pub fn cohere_from_env() -> Result<Self> {
319        let api_key = std::env::var("COHERE_API_KEY")
320            .map_err(|_| anyhow!("COHERE_API_KEY environment variable not set"))?;
321
322        let model = std::env::var("COHERE_EMBEDDING_MODEL")
323            .unwrap_or_else(|_| "embed-english-v3.0".to_string());
324
325        Ok(Self {
326            provider: EmbeddingProvider::Cohere,
327            api_key,
328            base_url: Some("https://api.cohere.ai/v1".to_string()),
329            model,
330            dimension: None,
331        })
332    }
333
334    /// 创建本地模型配置
335    pub fn local(model: impl Into<String>, dimension: Option<usize>) -> Self {
336        Self {
337            provider: EmbeddingProvider::Local,
338            api_key: String::new(),
339            base_url: None,
340            model: model.into(),
341            dimension,
342        }
343    }
344
345    /// 检查配置是否有效(本地模型和 Mock 不需要 API key)
346    pub fn is_valid(&self) -> bool {
347        matches!(
348            self.provider,
349            EmbeddingProvider::Local | EmbeddingProvider::Mock
350        ) || !self.api_key.is_empty()
351    }
352}
353
354// ============================================================================
355// OpenAI 实现
356// ============================================================================
357
/// OpenAI embedding model.
#[derive(Debug)]
pub struct OpenAIEmbeddings {
    /// HTTP client used for API requests.
    client: Client,
    /// Provider configuration (API key, base URL, model name).
    config: EmbeddingsConfig,
    /// Optional shared embedding cache; `None` disables caching.
    cache: Option<Arc<EmbeddingCache>>,
}
365
366impl OpenAIEmbeddings {
367    pub fn new(config: EmbeddingsConfig) -> Result<Self> {
368        if !config.is_valid() {
369            return Err(anyhow!("OpenAI Embeddings API not configured"));
370        }
371
372        Ok(Self {
373            client: Client::new(),
374            config,
375            cache: None,
376        })
377    }
378
379    pub fn with_cache(config: EmbeddingsConfig, cache: Arc<EmbeddingCache>) -> Result<Self> {
380        let mut embeddings = Self::new(config)?;
381        embeddings.cache = Some(cache);
382        Ok(embeddings)
383    }
384
385    fn base_url(&self) -> &str {
386        self.config
387            .base_url
388            .as_deref()
389            .unwrap_or("https://api.openai.com/v1")
390    }
391}
392
393#[async_trait]
394impl EmbeddingModel for OpenAIEmbeddings {
395    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
396        let embeddings = self.embed_batch(&[text.to_string()]).await?;
397        embeddings
398            .into_iter()
399            .next()
400            .ok_or_else(|| anyhow!("No embedding returned"))
401    }
402
403    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
404        if texts.is_empty() {
405            return Ok(Vec::new());
406        }
407
408        // 检查缓存
409        if let Some(cache) = &self.cache {
410            let cached = cache.get_batch("openai", &self.config.model, texts).await;
411            let all_cached = cached.iter().all(|c| c.is_some());
412            if all_cached {
413                return Ok(cached.into_iter().map(|c| c.unwrap()).collect());
414            }
415        }
416
417        let url = format!("{}/embeddings", self.base_url());
418
419        let request_body = OpenAiEmbeddingRequest {
420            model: self.config.model.clone(),
421            input: texts.to_vec(),
422            encoding_format: Some("float".to_string()),
423        };
424
425        tracing::debug!("Sending OpenAI embedding request for {} texts", texts.len());
426
427        let response = self
428            .client
429            .post(&url)
430            .header("Authorization", format!("Bearer {}", self.config.api_key))
431            .header("Content-Type", "application/json")
432            .json(&request_body)
433            .send()
434            .await?;
435
436        let status = response.status();
437        let response_text = response.text().await?;
438
439        if !status.is_success() {
440            tracing::error!("OpenAI Embedding API error: {} - {}", status, response_text);
441            return Err(anyhow!(
442                "OpenAI Embedding API request failed with status {}: {}",
443                status,
444                response_text
445            ));
446        }
447
448        let response_body: OpenAiEmbeddingResponse =
449            serde_json::from_str(&response_text).map_err(|e| {
450                anyhow!(
451                    "Failed to parse OpenAI embedding response: {} - {}",
452                    e,
453                    response_text
454                )
455            })?;
456
457        // 按 index 排序并提取向量
458        let mut embeddings: Vec<(usize, Vec<f32>)> = response_body
459            .data
460            .into_iter()
461            .map(|item| (item.index, item.embedding))
462            .collect();
463        embeddings.sort_by_key(|(idx, _)| *idx);
464        let result: Vec<Vec<f32>> = embeddings.into_iter().map(|(_, emb)| emb).collect();
465
466        // 存入缓存
467        if let Some(cache) = &self.cache {
468            for (text, embedding) in texts.iter().zip(result.iter()) {
469                cache
470                    .put("openai", &self.config.model, text, embedding.clone())
471                    .await;
472            }
473        }
474
475        Ok(result)
476    }
477
478    fn dimension(&self) -> usize {
479        match self.config.model.as_str() {
480            "text-embedding-ada-002" => 1536,
481            "text-embedding-3-small" => 1536,
482            "text-embedding-3-large" => 3072,
483            _ => DEFAULT_EMBEDDING_DIMENSION,
484        }
485    }
486
487    fn model_name(&self) -> &str {
488        &self.config.model
489    }
490
491    fn provider(&self) -> &str {
492        "openai"
493    }
494}
495
/// JSON request body sent to the OpenAI `/embeddings` endpoint.
#[derive(Serialize)]
struct OpenAiEmbeddingRequest {
    /// Model name.
    model: String,
    /// Texts to embed.
    input: Vec<String>,
    /// Requested encoding; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    encoding_format: Option<String>,
}
503
/// JSON response body of the OpenAI `/embeddings` endpoint.
#[derive(Deserialize)]
struct OpenAiEmbeddingResponse {
    /// One item per embedded input.
    data: Vec<OpenAiEmbeddingData>,
    // Deserialized (and therefore required in the payload) but not read.
    #[allow(dead_code)]
    model: String,
    // Deserialized (and therefore required in the payload) but not read.
    #[allow(dead_code)]
    usage: OpenAiEmbeddingUsage,
}
512
/// A single embedding item of an OpenAI response.
#[derive(Deserialize)]
struct OpenAiEmbeddingData {
    /// The embedding vector.
    embedding: Vec<f32>,
    /// Position used to order the items before they are returned to the caller.
    index: usize,
    // Deserialized (and therefore required in the payload) but not read.
    #[allow(dead_code)]
    object: String,
}
520
/// Token usage section of an OpenAI response.
// The fields are deserialized but never read, hence the blanket `dead_code` allowance.
#[derive(Deserialize)]
#[allow(dead_code)]
struct OpenAiEmbeddingUsage {
    prompt_tokens: u32,
    total_tokens: u32,
}
527
528// ============================================================================
529// HuggingFace 实现
530// ============================================================================
531
/// HuggingFace embedding model.
#[derive(Debug)]
pub struct HuggingFaceEmbeddings {
    /// HTTP client used for API requests.
    client: Client,
    /// Provider configuration (API key, model name).
    config: EmbeddingsConfig,
    /// Optional shared embedding cache; `None` disables caching.
    cache: Option<Arc<EmbeddingCache>>,
}
539
540impl HuggingFaceEmbeddings {
541    pub fn new(config: EmbeddingsConfig) -> Result<Self> {
542        if !config.is_valid() {
543            return Err(anyhow!("HuggingFace API not configured"));
544        }
545
546        Ok(Self {
547            client: Client::new(),
548            config,
549            cache: None,
550        })
551    }
552
553    pub fn with_cache(config: EmbeddingsConfig, cache: Arc<EmbeddingCache>) -> Result<Self> {
554        let mut embeddings = Self::new(config)?;
555        embeddings.cache = Some(cache);
556        Ok(embeddings)
557    }
558}
559
560#[async_trait]
561impl EmbeddingModel for HuggingFaceEmbeddings {
562    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
563        // HuggingFace API 返回格式取决于模型,通常需要单独调用
564        let embeddings = self.embed_batch(&[text.to_string()]).await?;
565        embeddings
566            .into_iter()
567            .next()
568            .ok_or_else(|| anyhow!("No embedding returned from HuggingFace"))
569    }
570
571    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
572        if texts.is_empty() {
573            return Ok(Vec::new());
574        }
575
576        // 检查缓存
577        if let Some(cache) = &self.cache {
578            let cached = cache
579                .get_batch("huggingface", &self.config.model, texts)
580                .await;
581            let all_cached = cached.iter().all(|c| c.is_some());
582            if all_cached {
583                return Ok(cached.into_iter().map(|c| c.unwrap()).collect());
584            }
585        }
586
587        let url = format!(
588            "https://api-inference.huggingface.co/pipeline/feature-extraction/{}",
589            self.config.model
590        );
591
592        tracing::debug!(
593            "Sending HuggingFace embedding request for {} texts",
594            texts.len()
595        );
596
597        let response = self
598            .client
599            .post(&url)
600            .header("Authorization", format!("Bearer {}", self.config.api_key))
601            .header("Content-Type", "application/json")
602            .json(&serde_json::json!({ "inputs": texts }))
603            .send()
604            .await?;
605
606        let status = response.status();
607        let response_text = response.text().await?;
608
609        if !status.is_success() {
610            tracing::error!("HuggingFace API error: {} - {}", status, response_text);
611            return Err(anyhow!(
612                "HuggingFace API request failed with status {}: {}",
613                status,
614                response_text
615            ));
616        }
617
618        // HuggingFace 返回格式: [[f32, f32, ...], ...] 或 [[f32], [f32], ...]
619        let embeddings: Vec<Vec<f32>> = serde_json::from_str(&response_text).map_err(|e| {
620            anyhow!(
621                "Failed to parse HuggingFace response: {} - {}",
622                e,
623                response_text
624            )
625        })?;
626
627        // 存入缓存
628        if let Some(cache) = &self.cache {
629            for (text, embedding) in texts.iter().zip(embeddings.iter()) {
630                cache
631                    .put("huggingface", &self.config.model, text, embedding.clone())
632                    .await;
633            }
634        }
635
636        Ok(embeddings)
637    }
638
639    fn dimension(&self) -> usize {
640        // 常见模型的维度
641        match self.config.model.as_str() {
642            "sentence-transformers/all-MiniLM-L6-v2" => 384,
643            "sentence-transformers/all-mpnet-base-v2" => 768,
644            "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" => 384,
645            _ => self.config.dimension.unwrap_or(768),
646        }
647    }
648
649    fn model_name(&self) -> &str {
650        &self.config.model
651    }
652
653    fn provider(&self) -> &str {
654        "huggingface"
655    }
656}
657
658// ============================================================================
659// Cohere 实现
660// ============================================================================
661
/// Cohere embedding model.
#[derive(Debug)]
pub struct CohereEmbeddings {
    /// HTTP client used for API requests.
    client: Client,
    /// Provider configuration (API key, model name).
    config: EmbeddingsConfig,
    /// Optional shared embedding cache; `None` disables caching.
    cache: Option<Arc<EmbeddingCache>>,
}
669
670impl CohereEmbeddings {
671    pub fn new(config: EmbeddingsConfig) -> Result<Self> {
672        if !config.is_valid() {
673            return Err(anyhow!("Cohere API not configured"));
674        }
675
676        Ok(Self {
677            client: Client::new(),
678            config,
679            cache: None,
680        })
681    }
682
683    pub fn with_cache(config: EmbeddingsConfig, cache: Arc<EmbeddingCache>) -> Result<Self> {
684        let mut embeddings = Self::new(config)?;
685        embeddings.cache = Some(cache);
686        Ok(embeddings)
687    }
688}
689
690#[async_trait]
691impl EmbeddingModel for CohereEmbeddings {
692    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
693        let embeddings = self.embed_batch(&[text.to_string()]).await?;
694        embeddings
695            .into_iter()
696            .next()
697            .ok_or_else(|| anyhow!("No embedding returned from Cohere"))
698    }
699
700    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
701        if texts.is_empty() {
702            return Ok(Vec::new());
703        }
704
705        // 检查缓存
706        if let Some(cache) = &self.cache {
707            let cached = cache.get_batch("cohere", &self.config.model, texts).await;
708            let all_cached = cached.iter().all(|c| c.is_some());
709            if all_cached {
710                return Ok(cached.into_iter().map(|c| c.unwrap()).collect());
711            }
712        }
713
714        let url = "https://api.cohere.ai/v1/embed";
715
716        let request_body = CohereEmbeddingRequest {
717            model: self.config.model.clone(),
718            texts: texts.to_vec(),
719            input_type: "search_document",
720            embedding_types: Some(vec!["float".to_string()]),
721        };
722
723        tracing::debug!("Sending Cohere embedding request for {} texts", texts.len());
724
725        let response = self
726            .client
727            .post(url)
728            .header("Authorization", format!("Bearer {}", self.config.api_key))
729            .header("Content-Type", "application/json")
730            .json(&request_body)
731            .send()
732            .await?;
733
734        let status = response.status();
735        let response_text = response.text().await?;
736
737        if !status.is_success() {
738            tracing::error!("Cohere API error: {} - {}", status, response_text);
739            return Err(anyhow!(
740                "Cohere API request failed with status {}: {}",
741                status,
742                response_text
743            ));
744        }
745
746        let response_body: CohereEmbeddingResponse = serde_json::from_str(&response_text)
747            .map_err(|e| anyhow!("Failed to parse Cohere response: {} - {}", e, response_text))?;
748
749        let result = response_body.embeddings.float;
750
751        // 存入缓存
752        if let Some(cache) = &self.cache {
753            for (text, embedding) in texts.iter().zip(result.iter()) {
754                cache
755                    .put("cohere", &self.config.model, text, embedding.clone())
756                    .await;
757            }
758        }
759
760        Ok(result)
761    }
762
763    fn dimension(&self) -> usize {
764        match self.config.model.as_str() {
765            "embed-english-v3.0" | "embed-english-light-v3.0" => 1024,
766            "embed-multilingual-v3.0" => 1024,
767            "embed-english-v2.0" => 4096,
768            _ => self.config.dimension.unwrap_or(1024),
769        }
770    }
771
772    fn model_name(&self) -> &str {
773        &self.config.model
774    }
775
776    fn provider(&self) -> &str {
777        "cohere"
778    }
779}
780
/// JSON request body sent to the Cohere embed endpoint.
#[derive(Serialize)]
struct CohereEmbeddingRequest {
    /// Model name.
    model: String,
    /// Texts to embed.
    texts: Vec<String>,
    /// Input type sent with the request.
    input_type: &'static str,
    /// Requested embedding encodings; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    embedding_types: Option<Vec<String>>,
}
789
790#[derive(Deserialize)]
791struct CohereEmbeddingResponse {
792    embeddings: CohereEmbeddingsData,
793    #[allow(dead_code)]
794    id: String,
795    #[allow(dead_code)]
796    text_type: String,
797}
798
/// The `embeddings` object of a Cohere response.
#[derive(Deserialize)]
struct CohereEmbeddingsData {
    /// Float embeddings, one vector per input text.
    float: Vec<Vec<f32>>,
}
803
804// ============================================================================
805// 本地模型实现 (SentenceTransformers)
806// ============================================================================
807
/// Local SentenceTransformers embedding model.
///
/// Note: this implementation requires the `candle` or `ort` feature to be enabled.
/// In a pure Rust environment a placeholder implementation is used.
#[derive(Debug)]
pub struct LocalEmbeddings {
    /// Model configuration (model name, optional dimension).
    config: EmbeddingsConfig,
    /// Optional shared embedding cache; `None` disables caching.
    cache: Option<Arc<EmbeddingCache>>,
    /// Loaded model backend; this field only exists with the `local-embeddings` feature.
    #[cfg(feature = "local-embeddings")]
    model: Option<std::sync::Mutex<Box<dyn LocalModelBackend>>>,
}
819
820impl LocalEmbeddings {
821    pub fn new(config: EmbeddingsConfig) -> Result<Self> {
822        Ok(Self {
823            config,
824            cache: None,
825            #[cfg(feature = "local-embeddings")]
826            model: None,
827        })
828    }
829
830    pub fn with_cache(config: EmbeddingsConfig, cache: Arc<EmbeddingCache>) -> Result<Self> {
831        let mut embeddings = Self::new(config)?;
832        embeddings.cache = Some(cache);
833        Ok(embeddings)
834    }
835
836    /// 加载本地模型
837    #[cfg(feature = "local-embeddings")]
838    pub fn load_model(&mut self) -> Result<()> {
839        // 使用 candle 或 ort 加载模型
840        // 这是一个占位实现
841        tracing::info!("Loading local embedding model: {}", self.config.model);
842        Ok(())
843    }
844}
845
846#[async_trait]
847impl EmbeddingModel for LocalEmbeddings {
848    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
849        // 检查缓存
850        if let Some(cache) = &self.cache {
851            if let Some(embedding) = cache.get("local", &self.config.model, text).await {
852                return Ok(embedding);
853            }
854        }
855
856        #[cfg(feature = "local-embeddings")]
857        {
858            // 实际实现使用 candle 或 ort
859            // 这里是占位代码
860            let embedding = vec![0.0f32; self.dimension()];
861
862            if let Some(cache) = &self.cache {
863                cache
864                    .put("local", &self.config.model, text, embedding.clone())
865                    .await;
866            }
867
868            Ok(embedding)
869        }
870
871        #[cfg(not(feature = "local-embeddings"))]
872        {
873            Err(anyhow!(
874                "Local embeddings require 'local-embeddings' feature. \
875                 Enable it in Cargo.toml and ensure candle or ort is available."
876            ))
877        }
878    }
879
880    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
881        // 检查缓存
882        if let Some(cache) = &self.cache {
883            let cached = cache.get_batch("local", &self.config.model, texts).await;
884            if cached.iter().all(|c| c.is_some()) {
885                return Ok(cached.into_iter().map(|c| c.unwrap()).collect());
886            }
887        }
888
889        #[cfg(feature = "local-embeddings")]
890        {
891            let mut results = Vec::with_capacity(texts.len());
892            for text in texts {
893                results.push(self.embed(text).await?);
894            }
895
896            // 存入缓存
897            if let Some(cache) = &self.cache {
898                for (text, embedding) in texts.iter().zip(results.iter()) {
899                    cache
900                        .put("local", &self.config.model, text, embedding.clone())
901                        .await;
902                }
903            }
904
905            Ok(results)
906        }
907
908        #[cfg(not(feature = "local-embeddings"))]
909        {
910            Err(anyhow!(
911                "Local embeddings require 'local-embeddings' feature"
912            ))
913        }
914    }
915
916    fn dimension(&self) -> usize {
917        self.config.dimension.unwrap_or(384)
918    }
919
920    fn model_name(&self) -> &str {
921        &self.config.model
922    }
923
924    fn provider(&self) -> &str {
925        "local"
926    }
927}
928
/// Backend abstraction for a locally loaded embedding model.
///
/// `Debug` is a supertrait because `LocalEmbeddings` derives `Debug` while holding a
/// `Mutex<Box<dyn LocalModelBackend>>`; the derive cannot compile unless the trait
/// object itself is `Debug`.
#[cfg(feature = "local-embeddings")]
trait LocalModelBackend: Send + Sync + std::fmt::Debug {
    /// Encodes `text` into an embedding vector.
    fn encode(&self, text: &str) -> Result<Vec<f32>>;
}
934
935// ============================================================================
936// 统一 Embeddings 工厂
937// ============================================================================
938
/// Factory for embedding models.
pub struct EmbeddingsFactory {
    /// Cache handed to the OpenAI, HuggingFace, Cohere and Local models it creates.
    cache: Arc<EmbeddingCache>,
}
943
944impl EmbeddingsFactory {
945    pub fn new() -> Self {
946        Self {
947            cache: Arc::new(EmbeddingCache::default_cache()),
948        }
949    }
950
951    pub fn with_cache(cache: Arc<EmbeddingCache>) -> Self {
952        Self { cache }
953    }
954
955    /// 创建嵌入模型实例
956    pub fn create(&self, config: EmbeddingsConfig) -> Result<Box<dyn EmbeddingModel>> {
957        match config.provider {
958            EmbeddingProvider::OpenAI => Ok(Box::new(OpenAIEmbeddings::with_cache(
959                config,
960                self.cache.clone(),
961            )?)),
962            EmbeddingProvider::HuggingFace => Ok(Box::new(HuggingFaceEmbeddings::with_cache(
963                config,
964                self.cache.clone(),
965            )?)),
966            EmbeddingProvider::Cohere => Ok(Box::new(CohereEmbeddings::with_cache(
967                config,
968                self.cache.clone(),
969            )?)),
970            EmbeddingProvider::Local => Ok(Box::new(LocalEmbeddings::with_cache(
971                config,
972                self.cache.clone(),
973            )?)),
974            EmbeddingProvider::Mock => {
975                let dimension = config.dimension.unwrap_or(DEFAULT_EMBEDDING_DIMENSION);
976                #[cfg(any(feature = "mock", test))]
977                {
978                    Ok(Box::new(MockEmbeddingModel::with_name(
979                        dimension,
980                        &config.model,
981                    )))
982                }
983                #[cfg(not(any(feature = "mock", test)))]
984                {
985                    // 当 mock feature 未启用时,使用 LocalEmbeddings 作为回退
986                    let local_config = EmbeddingsConfig::local(&config.model, Some(dimension));
987                    Ok(Box::new(LocalEmbeddings::new(local_config)?))
988                }
989            }
990        }
991    }
992
993    /// 创建安全的嵌入模型实例
994    ///
995    /// 如果指定的配置无效,自动回退到 Mock 模型。
996    /// 这确保了即使在未配置环境下也能安全返回一个可用实例。
997    pub fn create_safe(&self, config: EmbeddingsConfig) -> Box<dyn EmbeddingModel> {
998        if config.is_valid() {
999            self.create(config)
1000                .unwrap_or_else(|_| self.create_mock_default())
1001        } else {
1002            self.create_mock_default()
1003        }
1004    }
1005
1006    /// 创建默认 Mock 模型
1007    fn create_mock_default(&self) -> Box<dyn EmbeddingModel> {
1008        #[cfg(any(feature = "mock", test))]
1009        {
1010            Box::new(MockEmbeddingModel::new(DEFAULT_EMBEDDING_DIMENSION))
1011        }
1012        #[cfg(not(any(feature = "mock", test)))]
1013        {
1014            // 如果没有 mock feature,使用 LocalEmbeddings 作为安全回退
1015            let config = EmbeddingsConfig::local("fallback", Some(DEFAULT_EMBEDDING_DIMENSION));
1016            Box::new(LocalEmbeddings::new(config).expect("Local embeddings should always work"))
1017        }
1018    }
1019
1020    /// 创建 OpenAI 嵌入模型
1021    pub fn openai(&self) -> Result<Box<dyn EmbeddingModel>> {
1022        let config = EmbeddingsConfig::openai_from_env()?;
1023        self.create(config)
1024    }
1025
1026    /// 创建 HuggingFace 嵌入模型
1027    pub fn huggingface(&self) -> Result<Box<dyn EmbeddingModel>> {
1028        let config = EmbeddingsConfig::huggingface_from_env()?;
1029        self.create(config)
1030    }
1031
1032    /// 创建 Cohere 嵌入模型
1033    pub fn cohere(&self) -> Result<Box<dyn EmbeddingModel>> {
1034        let config = EmbeddingsConfig::cohere_from_env()?;
1035        self.create(config)
1036    }
1037
1038    /// 创建本地嵌入模型
1039    pub fn local(&self, model: &str, dimension: Option<usize>) -> Result<Box<dyn EmbeddingModel>> {
1040        let config = EmbeddingsConfig::local(model, dimension);
1041        self.create(config)
1042    }
1043
1044    /// 创建 Mock 嵌入模型(仅测试/开发使用)
1045    ///
1046    /// **安全默认值**: Mock 模型返回零向量,不调用任何外部 API。
1047    /// 这是在未配置环境下的安全回退选项。
1048    #[cfg(any(feature = "mock", test))]
1049    pub fn mock(&self, dimension: usize) -> Box<dyn EmbeddingModel> {
1050        Box::new(MockEmbeddingModel::new(dimension))
1051    }
1052
1053    /// 获取缓存实例
1054    pub fn cache(&self) -> Arc<EmbeddingCache> {
1055        self.cache.clone()
1056    }
1057}
1058
1059impl Default for EmbeddingsFactory {
1060    fn default() -> Self {
1061        Self::new()
1062    }
1063}
1064
1065// ============================================================================
1066// Mock 嵌入模型(仅测试/开发使用)
1067// ============================================================================
1068
/// Mock embedding model.
///
/// Intended for tests or as a fallback: it stores a fixed dimension and a
/// model name and performs no real embedding work.
///
/// **Note**: this type is only compiled when the `mock` feature is enabled or
/// in test builds. Production code should not use it.
#[cfg(any(feature = "mock", test))]
#[derive(Debug, Clone)]
pub struct MockEmbeddingModel {
    /// Length of every vector this model reports and produces.
    dimension: usize,
    /// Name reported as the model name.
    model_name: String,
}
1080
#[cfg(any(feature = "mock", test))]
impl MockEmbeddingModel {
    /// Creates a mock model of the given dimension named `"mock-embedding"`.
    pub fn new(dimension: usize) -> Self {
        Self::with_name(dimension, "mock-embedding")
    }

    /// Creates a mock model of the given dimension with a caller-chosen name.
    ///
    /// `model_name` accepts anything convertible into a `String`.
    pub fn with_name(dimension: usize, model_name: impl Into<String>) -> Self {
        let model_name = model_name.into();
        Self {
            dimension,
            model_name,
        }
    }
}
1099
/// `EmbeddingModel` implementation that never fails and always yields zeros.
#[cfg(any(feature = "mock", test))]
#[async_trait]
impl EmbeddingModel for MockEmbeddingModel {
    /// Returns a zero vector of length `self.dimension`; the text is ignored.
    async fn embed(&self, _text: &str) -> Result<Vec<f32>> {
        let zeros = vec![0.0_f32; self.dimension];
        Ok(zeros)
    }

    /// Returns one zero vector of length `self.dimension` per input text.
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        let zeros = vec![0.0_f32; self.dimension];
        Ok(vec![zeros; texts.len()])
    }

    /// Returns the dimension the model was constructed with.
    fn dimension(&self) -> usize {
        self.dimension
    }

    /// Returns the model name the model was constructed with.
    fn model_name(&self) -> &str {
        self.model_name.as_str()
    }

    /// Always returns `"mock"`.
    fn provider(&self) -> &str {
        "mock"
    }
}
1123
1124// ============================================================================
1125// 向后兼容:保留原有 Embeddings 类型别名
1126// ============================================================================
1127
/// Backward-compatible `Embeddings` type.
///
/// Kept so existing code that refers to `Embeddings` continues to compile;
/// it is an alias for the OpenAI implementation.
pub type Embeddings = OpenAIEmbeddings;
1132
1133// ============================================================================
1134// 测试
1135// ============================================================================
1136
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes every test that sets or removes process environment variables.
    ///
    /// The test harness runs tests on several threads by default, while the
    /// environment is process-global. Without this lock one test can remove
    /// `OPENAI_API_KEY` between another test's `set_var` and its read, making
    /// the suite flaky.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires [`ENV_LOCK`], ignoring poisoning.
    ///
    /// A failed assertion in one env test must not cascade into `PoisonError`
    /// panics in the others; the guarded value is `()`, so there is no state
    /// that could have been left inconsistent.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    // ==========================================================================
    // Cache tests
    // ==========================================================================

    #[tokio::test]
    async fn test_cache_basic_operations() {
        let cache = EmbeddingCache::new(100, 3600);

        // put followed by get returns the stored vector
        let embedding = vec![0.1f32, 0.2, 0.3];
        cache
            .put("openai", "test-model", "hello", embedding.clone())
            .await;

        let cached = cache.get("openai", "test-model", "hello").await;
        assert!(cached.is_some());
        assert_eq!(cached.unwrap(), embedding);

        // cache miss
        let not_cached = cache.get("openai", "test-model", "not-exists").await;
        assert!(not_cached.is_none());
    }

    #[tokio::test]
    async fn test_cache_batch_operations() {
        let cache = EmbeddingCache::new(100, 3600);

        let texts: Vec<String> = vec!["a".to_string(), "b".to_string(), "c".to_string()];
        let embeddings: Vec<Vec<f32>> = texts.iter().map(|t| vec![t.len() as f32]).collect();

        for (text, emb) in texts.iter().zip(embeddings.iter()) {
            cache.put("test", "model", text, emb.clone()).await;
        }

        let cached = cache.get_batch("test", "model", &texts).await;
        assert!(cached.iter().all(|c| c.is_some()));
    }

    #[tokio::test]
    async fn test_cache_stats() {
        let cache = EmbeddingCache::new(100, 3600);

        cache.put("test", "model", "a", vec![1.0f32]).await;
        cache.put("test", "model", "b", vec![2.0]).await;

        let _ = cache.get("test", "model", "a").await;
        let _ = cache.get("test", "model", "a").await;

        let stats = cache.stats().await;
        assert_eq!(stats.total_entries, 2);
        assert_eq!(stats.total_access, 2);
    }

    // ==========================================================================
    // Configuration tests
    // ==========================================================================

    #[test]
    fn test_config_openai_from_env() {
        let _env = lock_env();
        std::env::set_var("OPENAI_API_KEY", "test_key");
        std::env::remove_var("OPENAI_BASE_URL");
        std::env::remove_var("OPENAI_EMBEDDING_MODEL");

        let config = EmbeddingsConfig::openai_from_env().unwrap();
        assert_eq!(config.api_key, "test_key");
        assert_eq!(config.model, DEFAULT_EMBEDDING_MODEL);

        std::env::remove_var("OPENAI_API_KEY");
    }

    #[test]
    fn test_config_huggingface_from_env() {
        let _env = lock_env();
        std::env::set_var("HUGGINGFACE_API_KEY", "hf_test");
        std::env::remove_var("HUGGINGFACE_EMBEDDING_MODEL");

        let config = EmbeddingsConfig::huggingface_from_env().unwrap();
        assert_eq!(config.api_key, "hf_test");
        assert!(config.model.contains("sentence-transformers"));

        std::env::remove_var("HUGGINGFACE_API_KEY");
    }

    #[test]
    fn test_config_cohere_from_env() {
        let _env = lock_env();
        std::env::set_var("COHERE_API_KEY", "cohere_test");
        std::env::remove_var("COHERE_EMBEDDING_MODEL");

        let config = EmbeddingsConfig::cohere_from_env().unwrap();
        assert_eq!(config.api_key, "cohere_test");
        assert!(config.model.starts_with("embed-"));

        std::env::remove_var("COHERE_API_KEY");
    }

    #[test]
    fn test_config_local() {
        let config = EmbeddingsConfig::local("all-MiniLM-L6-v2", Some(384));
        assert_eq!(config.provider, EmbeddingProvider::Local);
        assert!(config.api_key.is_empty());
        assert!(config.is_valid()); // local models need no API key
    }

    // ==========================================================================
    // Dimension tests
    // ==========================================================================

    #[test]
    fn test_openai_dimension() {
        let config = EmbeddingsConfig {
            provider: EmbeddingProvider::OpenAI,
            api_key: "test".to_string(),
            base_url: None,
            model: "text-embedding-ada-002".to_string(),
            dimension: None,
        };
        let embeddings = OpenAIEmbeddings::new(config).unwrap();
        assert_eq!(embeddings.dimension(), 1536);

        let config = EmbeddingsConfig {
            provider: EmbeddingProvider::OpenAI,
            api_key: "test".to_string(),
            base_url: None,
            model: "text-embedding-3-large".to_string(),
            dimension: None,
        };
        let embeddings = OpenAIEmbeddings::new(config).unwrap();
        assert_eq!(embeddings.dimension(), 3072);
    }

    #[test]
    fn test_huggingface_dimension() {
        let config = EmbeddingsConfig {
            provider: EmbeddingProvider::HuggingFace,
            api_key: "test".to_string(),
            base_url: None,
            model: "sentence-transformers/all-MiniLM-L6-v2".to_string(),
            dimension: None,
        };
        let embeddings = HuggingFaceEmbeddings::new(config).unwrap();
        assert_eq!(embeddings.dimension(), 384);
    }

    #[test]
    fn test_cohere_dimension() {
        let config = EmbeddingsConfig {
            provider: EmbeddingProvider::Cohere,
            api_key: "test".to_string(),
            base_url: None,
            model: "embed-english-v3.0".to_string(),
            dimension: None,
        };
        let embeddings = CohereEmbeddings::new(config).unwrap();
        assert_eq!(embeddings.dimension(), 1024);
    }

    // ==========================================================================
    // Factory tests
    // ==========================================================================

    #[test]
    fn test_factory_create_openai() {
        let _env = lock_env();
        std::env::set_var("OPENAI_API_KEY", "test_key");

        let factory = EmbeddingsFactory::new();
        let model = factory.openai().unwrap();
        assert_eq!(model.provider(), "openai");

        std::env::remove_var("OPENAI_API_KEY");
    }

    #[test]
    fn test_factory_create_local() {
        let factory = EmbeddingsFactory::new();
        let model = factory.local("test-model", Some(384)).unwrap();
        assert_eq!(model.provider(), "local");
        assert_eq!(model.dimension(), 384);
    }

    #[test]
    fn test_factory_create_mock() {
        let factory = EmbeddingsFactory::new();
        let model = factory.mock(512);
        assert_eq!(model.provider(), "mock");
        assert_eq!(model.dimension(), 512);
    }

    #[test]
    fn test_factory_create_safe_with_invalid_config() {
        let factory = EmbeddingsFactory::new();
        // An OpenAI config with an empty api_key is invalid
        let config = EmbeddingsConfig {
            provider: EmbeddingProvider::OpenAI,
            api_key: String::new(),
            base_url: None,
            model: "test".to_string(),
            dimension: None,
        };
        let model = factory.create_safe(config);
        // should fall back to the mock model
        assert_eq!(model.provider(), "mock");
    }

    #[test]
    fn test_factory_create_safe_with_valid_config() {
        let _env = lock_env();
        std::env::set_var("OPENAI_API_KEY", "test_key");
        let factory = EmbeddingsFactory::new();
        let config = EmbeddingsConfig::openai_from_env().unwrap();
        let model = factory.create_safe(config);
        assert_eq!(model.provider(), "openai");
        std::env::remove_var("OPENAI_API_KEY");
    }

    // ==========================================================================
    // Safe-default tests
    // ==========================================================================

    #[test]
    fn test_config_default_is_safe() {
        let config = EmbeddingsConfig::default();
        // the default config should use the Mock provider
        assert_eq!(config.provider, EmbeddingProvider::Mock);
        // the Mock provider needs no API key, so the config should be valid
        assert!(config.is_valid());
    }

    #[test]
    fn test_provider_mock_is_valid() {
        let config = EmbeddingsConfig {
            provider: EmbeddingProvider::Mock,
            api_key: String::new(),
            base_url: None,
            model: "mock-test".to_string(),
            dimension: Some(256),
        };
        assert!(config.is_valid());
    }

    #[test]
    fn test_embeddings_factory_mock_default_dimension() {
        let factory = EmbeddingsFactory::new();
        let model = factory.mock(DEFAULT_EMBEDDING_DIMENSION);
        assert_eq!(model.dimension(), DEFAULT_EMBEDDING_DIMENSION);
    }

    // ==========================================================================
    // Backward-compatibility tests
    // ==========================================================================

    #[test]
    fn test_backward_compatible_embeddings() {
        let _env = lock_env();
        std::env::set_var("OPENAI_API_KEY", "test_key");

        let config = EmbeddingsConfig::openai_from_env().unwrap();
        let embeddings = Embeddings::new(config).unwrap();
        assert_eq!(embeddings.provider(), "openai");

        std::env::remove_var("OPENAI_API_KEY");
    }
}