Skip to main content

sh_layer3/
retriever.rs

1//! # Hybrid Retriever
2//!
3//! 混合检索器:结合 BM25 关键词检索和向量相似度检索。
4//!
5//! ## 功能
6//!
7//! - BM25 关键词检索
8//! - 向量相似度检索
9//! - Reciprocal Rank Fusion (RRF) 融合算法
10//! - 可配置的权重支持
11
12use crate::retriever_engine::{Document, EmbeddingModel, RetrievalResult};
13use crate::types::Layer3Result;
14use crate::vector_store::{MetadataFilter, VectorStore};
15use async_trait::async_trait;
16use parking_lot::RwLock;
17use sh_layer2::generate_short_id;
18use std::collections::{HashMap, HashSet};
19use std::sync::Arc;
20use tracing::instrument;
21
22// ============================================================================
23// BM25 Implementation
24// ============================================================================
25
26/// BM25 检索器
27///
28/// 实现经典的 BM25 排序算法,用于关键词检索。
29pub struct BM25Index {
30    /// 文档 ID -> 文档内容
31    documents: Arc<RwLock<HashMap<String, String>>>,
32    /// 文档 ID -> 词项频率
33    term_frequencies: Arc<RwLock<HashMap<String, HashMap<String, usize>>>>,
34    /// 逆文档频率缓存
35    idf_cache: Arc<RwLock<HashMap<String, f64>>>,
36    /// 文档平均长度
37    avg_doc_length: Arc<RwLock<f64>>,
38    /// 总文档数
39    doc_count: Arc<RwLock<usize>>,
40    /// BM25 参数 k1 (词项饱和参数)
41    k1: f64,
42    /// BM25 参数 b (文档长度归一化参数)
43    b: f64,
44}
45
46impl BM25Index {
47    /// 创建新的 BM25 索引
48    pub fn new() -> Self {
49        Self {
50            documents: Arc::new(RwLock::new(HashMap::new())),
51            term_frequencies: Arc::new(RwLock::new(HashMap::new())),
52            idf_cache: Arc::new(RwLock::new(HashMap::new())),
53            avg_doc_length: Arc::new(RwLock::new(0.0)),
54            doc_count: Arc::new(RwLock::new(0)),
55            k1: 1.2,
56            b: 0.75,
57        }
58    }
59
60    /// 使用自定义参数创建 BM25 索引
61    pub fn with_params(k1: f64, b: f64) -> Self {
62        Self {
63            documents: Arc::new(RwLock::new(HashMap::new())),
64            term_frequencies: Arc::new(RwLock::new(HashMap::new())),
65            idf_cache: Arc::new(RwLock::new(HashMap::new())),
66            avg_doc_length: Arc::new(RwLock::new(0.0)),
67            doc_count: Arc::new(RwLock::new(0)),
68            k1,
69            b,
70        }
71    }
72
73    /// 添加文档到索引
74    pub fn add_document(&self, doc_id: String, content: &str) {
75        let tokens = self.tokenize(content);
76        let mut tf: HashMap<String, usize> = HashMap::new();
77
78        for token in tokens {
79            *tf.entry(token).or_insert(0) += 1;
80        }
81
82        let doc_length = content.split_whitespace().count();
83
84        {
85            let mut documents = self.documents.write();
86            documents.insert(doc_id.clone(), content.to_lowercase());
87        }
88
89        {
90            let mut term_frequencies = self.term_frequencies.write();
91            term_frequencies.insert(doc_id, tf);
92        }
93
94        // 更新统计信息
95        {
96            let mut avg_len = self.avg_doc_length.write();
97            let mut count = self.doc_count.write();
98
99            let old_count = *count;
100            let old_avg = *avg_len;
101            let new_count = old_count + 1;
102            *avg_len = (old_avg * old_count as f64 + doc_length as f64) / new_count as f64;
103            *count = new_count;
104        }
105
106        // 清除 IDF 缓存(需要重新计算)
107        self.idf_cache.write().clear();
108    }
109
110    /// 批量添加文档
111    pub fn add_documents(&self, docs: Vec<(String, String)>) {
112        for (doc_id, content) in docs {
113            self.add_document(doc_id, &content);
114        }
115    }
116
117    /// 删除文档
118    pub fn remove_document(&self, doc_id: &str) -> bool {
119        let removed = {
120            let mut documents = self.documents.write();
121            documents.remove(doc_id).is_some()
122        };
123
124        if removed {
125            let mut term_frequencies = self.term_frequencies.write();
126            term_frequencies.remove(doc_id);
127
128            // 更新文档计数
129            {
130                let mut count = self.doc_count.write();
131                if *count > 0 {
132                    *count -= 1;
133                }
134            }
135
136            // 清除 IDF 缓存
137            self.idf_cache.write().clear();
138        }
139
140        removed
141    }
142
143    /// 清空索引
144    pub fn clear(&self) {
145        self.documents.write().clear();
146        self.term_frequencies.write().clear();
147        self.idf_cache.write().clear();
148        *self.avg_doc_length.write() = 0.0;
149        *self.doc_count.write() = 0;
150    }
151
152    /// BM25 搜索
153    pub fn search(&self, query: &str, top_k: usize) -> Vec<(String, f64)> {
154        let query_tokens = self.tokenize(query);
155
156        if query_tokens.is_empty() {
157            return Vec::new();
158        }
159
160        let documents = self.documents.read();
161        let term_frequencies = self.term_frequencies.read();
162        let avg_doc_length = *self.avg_doc_length.read();
163        let doc_count = *self.doc_count.read();
164
165        if doc_count == 0 {
166            return Vec::new();
167        }
168
169        let mut scores: Vec<(String, f64)> = documents
170            .keys()
171            .filter_map(|doc_id| {
172                let score = self.compute_bm25_score(
173                    doc_id,
174                    &query_tokens,
175                    &term_frequencies,
176                    avg_doc_length,
177                    doc_count,
178                );
179                if score > 0.0 {
180                    Some((doc_id.clone(), score))
181                } else {
182                    None
183                }
184            })
185            .collect();
186
187        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
188        scores.truncate(top_k);
189
190        scores
191    }
192
193    /// 计算 BM25 分数
194    fn compute_bm25_score(
195        &self,
196        doc_id: &str,
197        query_tokens: &[String],
198        term_frequencies: &HashMap<String, HashMap<String, usize>>,
199        avg_doc_length: f64,
200        doc_count: usize,
201    ) -> f64 {
202        let doc_tf = match term_frequencies.get(doc_id) {
203            Some(tf) => tf,
204            None => return 0.0,
205        };
206
207        let documents = self.documents.read();
208        let doc_content = match documents.get(doc_id) {
209            Some(content) => content,
210            None => return 0.0,
211        };
212
213        let doc_length = doc_content.split_whitespace().count() as f64;
214        let mut idf_cache = self.idf_cache.write();
215
216        let mut score = 0.0;
217
218        for token in query_tokens {
219            let tf = *doc_tf.get(token).unwrap_or(&0) as f64;
220
221            if tf == 0.0 {
222                continue;
223            }
224
225            // 计算 IDF
226            let idf = *idf_cache.entry(token.clone()).or_insert_with(|| {
227                let df = self.compute_document_frequency(token);
228                let n = doc_count as f64;
229                ((n - df + 0.5) / (df + 0.5) + 1.0).ln()
230            });
231
232            // BM25 公式
233            let numerator = tf * (self.k1 + 1.0);
234            let denominator =
235                tf + self.k1 * (1.0 - self.b + self.b * (doc_length / avg_doc_length));
236
237            score += idf * (numerator / denominator);
238        }
239
240        score
241    }
242
243    /// 计算词项的文档频率
244    fn compute_document_frequency(&self, term: &str) -> f64 {
245        let term_frequencies = self.term_frequencies.read();
246        term_frequencies
247            .values()
248            .filter(|tf| tf.contains_key(term))
249            .count() as f64
250    }
251
252    /// 分词
253    fn tokenize(&self, text: &str) -> Vec<String> {
254        let stop_words: HashSet<&str> = [
255            "the", "a", "an", "is", "are", "was", "were", "be", "been", "being", "have", "has",
256            "had", "do", "does", "did", "will", "would", "could", "should", "may", "might", "must",
257            "shall", "can", "need", "dare", "ought", "used", "to", "of", "in", "for", "on", "with",
258            "at", "by", "from", "as", "into", "through", "during", "before", "after", "above",
259            "below", "between", "under", "again", "further", "then", "once", "here", "there",
260            "when", "where", "why", "how", "all", "each", "few", "more", "most", "other", "some",
261            "such", "no", "nor", "not", "only", "own", "same", "so", "than", "too", "very", "s",
262            "t", "just", "and", "but", "if", "or", "because", "until", "while", "although",
263        ]
264        .iter()
265        .cloned()
266        .collect();
267
268        text.to_lowercase()
269            .split_whitespace()
270            .filter(|w| !stop_words.contains(*w) && w.len() > 1)
271            .map(|s| s.to_string())
272            .collect()
273    }
274
275    /// 获取文档数量
276    pub fn doc_count(&self) -> usize {
277        *self.doc_count.read()
278    }
279}
280
281impl Default for BM25Index {
282    fn default() -> Self {
283        Self::new()
284    }
285}
286
287// ============================================================================
288// Reciprocal Rank Fusion
289// ============================================================================
290
/// Reciprocal Rank Fusion (RRF) combiner.
///
/// Merges several ranked result lists into a single ranking.
pub struct ReciprocalRankFusion {
    /// RRF constant `k`: each list contributes `1 / (k + rank)` for a document
    /// at 1-based position `rank`.
    k: f64,
}

impl ReciprocalRankFusion {
    /// Creates a combiner with the given `k`.
    pub fn new(k: f64) -> Self {
        Self { k }
    }

    /// Creates a combiner with the default constant `k = 60`.
    pub fn default_fusion() -> Self {
        Self::new(60.0)
    }

    /// Fuses several result lists using rank positions only.
    ///
    /// # Arguments
    /// * `result_lists` - ranked lists of `(doc_id, score)`; the scores are ignored
    /// * `top_k` - maximum number of results to return
    ///
    /// # Returns
    /// The fused results sorted by descending RRF score. Documents with equal
    /// scores are ordered by ID so the output is deterministic.
    pub fn fuse(&self, result_lists: &[Vec<(String, f64)>], top_k: usize) -> Vec<(String, f64)> {
        let mut rrf_scores: HashMap<String, f64> = HashMap::new();

        for results in result_lists {
            for (rank, (doc_id, _original_score)) in results.iter().enumerate() {
                *rrf_scores.entry(doc_id.clone()).or_insert(0.0) += self.rank_score(rank);
            }
        }

        Self::rank_and_truncate(rrf_scores, top_k)
    }

    /// Fuses several result lists, blending in the original scores and a per-list weight.
    ///
    /// Each occurrence contributes `(rrf_score + original_score * 0.1) * weight`.
    ///
    /// # Arguments
    /// * `result_lists` - ranked lists of `(doc_id, score)`
    /// * `weights` - one weight per result list
    /// * `top_k` - maximum number of results to return
    ///
    /// # Panics
    /// Panics if `result_lists` and `weights` differ in length.
    pub fn fuse_with_weights(
        &self,
        result_lists: &[Vec<(String, f64)>],
        weights: &[f64],
        top_k: usize,
    ) -> Vec<(String, f64)> {
        assert!(
            result_lists.len() == weights.len(),
            "Result lists and weights must have the same length"
        );

        let mut combined_scores: HashMap<String, f64> = HashMap::new();

        for (results, weight) in result_lists.iter().zip(weights.iter()) {
            for (rank, (doc_id, original_score)) in results.iter().enumerate() {
                let weighted_score = (self.rank_score(rank) + original_score * 0.1) * weight;
                *combined_scores.entry(doc_id.clone()).or_insert(0.0) += weighted_score;
            }
        }

        Self::rank_and_truncate(combined_scores, top_k)
    }

    /// RRF contribution of the entry at 0-based position `rank`.
    fn rank_score(&self, rank: usize) -> f64 {
        1.0 / (self.k + (rank + 1) as f64)
    }

    /// Sorts by descending score (ties broken by ascending doc ID) and keeps `top_k` entries.
    ///
    /// The tie-break matters because the scores come out of a `HashMap`, whose
    /// iteration order differs from run to run.
    fn rank_and_truncate(scores: HashMap<String, f64>, top_k: usize) -> Vec<(String, f64)> {
        let mut fused: Vec<(String, f64)> = scores.into_iter().collect();
        fused.sort_by(|a, b| {
            b.1.partial_cmp(&a.1)
                .unwrap_or(std::cmp::Ordering::Equal)
                .then_with(|| a.0.cmp(&b.0))
        });
        fused.truncate(top_k);
        fused
    }
}

impl Default for ReciprocalRankFusion {
    fn default() -> Self {
        Self::default_fusion()
    }
}
374
375// ============================================================================
376// Hybrid Retriever Configuration
377// ============================================================================
378
/// Configuration for hybrid retrieval.
#[derive(Debug, Clone)]
pub struct HybridRetrieverConfig {
    /// Weight of the vector-similarity results.
    pub vector_weight: f64,
    /// Weight of the BM25 results.
    pub bm25_weight: f64,
    /// RRF constant `k`.
    pub rrf_k: f64,
    /// Whether result lists are fused with RRF.
    pub use_rrf: bool,
    /// Multiplier applied to `top_k` when fetching candidates from each retriever.
    pub candidate_multiplier: usize,
    /// Minimum score a result must reach to be kept.
    pub min_score_threshold: f64,
}

impl HybridRetrieverConfig {
    /// Creates the default configuration: 0.7 vector / 0.3 BM25, RRF enabled
    /// with `k = 60`, candidate multiplier 2, no score threshold.
    pub fn new() -> Self {
        Self {
            vector_weight: 0.7,
            bm25_weight: 0.3,
            rrf_k: 60.0,
            use_rrf: true,
            candidate_multiplier: 2,
            min_score_threshold: 0.0,
        }
    }

    /// Creates a configuration that gives all the weight to vector retrieval.
    pub fn vector_only() -> Self {
        Self {
            vector_weight: 1.0,
            bm25_weight: 0.0,
            ..Self::new()
        }
    }

    /// Creates a configuration that gives all the weight to BM25 retrieval.
    pub fn bm25_only() -> Self {
        Self {
            vector_weight: 0.0,
            bm25_weight: 1.0,
            ..Self::new()
        }
    }

    /// Creates a configuration that weights both retrievers equally.
    pub fn balanced() -> Self {
        Self {
            vector_weight: 0.5,
            bm25_weight: 0.5,
            ..Self::new()
        }
    }

    /// Sets both weights, normalised so that they sum to 1.
    ///
    /// When the weights do not add up to a positive number (for example both
    /// are `0.0`), they are stored as given rather than divided by the sum,
    /// which would otherwise yield NaN or infinite weights.
    pub fn with_weights(mut self, vector: f64, bm25: f64) -> Self {
        self.vector_weight = vector;
        self.bm25_weight = bm25;
        self.normalize_weights();
        self
    }

    /// Enables or disables RRF fusion and sets its constant `k`.
    pub fn with_rrf(mut self, enabled: bool, k: f64) -> Self {
        self.use_rrf = enabled;
        self.rrf_k = k;
        self
    }

    /// Sets the candidate multiplier.
    pub fn with_candidate_multiplier(mut self, multiplier: usize) -> Self {
        self.candidate_multiplier = multiplier;
        self
    }

    /// Sets the minimum score threshold.
    pub fn with_min_score(mut self, threshold: f64) -> Self {
        self.min_score_threshold = threshold;
        self
    }

    /// Rescales the two weights so they sum to 1; does nothing unless their sum is positive.
    pub fn normalize_weights(&mut self) {
        let total = self.vector_weight + self.bm25_weight;
        if total > 0.0 {
            self.vector_weight /= total;
            self.bm25_weight /= total;
        }
    }
}

impl Default for HybridRetrieverConfig {
    fn default() -> Self {
        Self::new()
    }
}
478
479// ============================================================================
480// Hybrid Retriever Trait
481// ============================================================================
482
483/// 混合检索器 trait
484#[async_trait]
485pub trait HybridRetriever: Send + Sync {
486    /// 索引文档
487    async fn index_documents(&self, documents: Vec<Document>) -> Layer3Result<Vec<String>>;
488
489    /// 混合检索
490    async fn retrieve(
491        &self,
492        query: &str,
493        top_k: usize,
494        config: Option<&HybridRetrieverConfig>,
495    ) -> Layer3Result<Vec<RetrievalResult>>;
496
497    /// 带过滤条件的混合检索
498    async fn retrieve_with_filter(
499        &self,
500        query: &str,
501        top_k: usize,
502        filter: Option<MetadataFilter>,
503        config: Option<&HybridRetrieverConfig>,
504    ) -> Layer3Result<Vec<RetrievalResult>>;
505
506    /// 删除文档
507    async fn delete_documents(&self, doc_ids: &[String]) -> Layer3Result<bool>;
508
509    /// 清空索引
510    async fn clear(&self) -> Layer3Result<bool>;
511
512    /// 获取文档数量
513    async fn count(&self) -> Layer3Result<usize>;
514}
515
516// ============================================================================
517// Default Hybrid Retriever Implementation
518// ============================================================================
519
520/// 文档缓存条目类型
521type DocCacheEntry = (String, HashMap<String, serde_json::Value>);
522
523/// 默认混合检索器实现
524///
525/// 结合 BM25 和向量检索,使用 RRF 融合结果。
526pub struct DefaultHybridRetriever<VS, EM>
527where
528    VS: VectorStore,
529    EM: EmbeddingModel,
530{
531    /// 向量存储
532    vector_store: VS,
533    /// Embedding 模型
534    embedding_model: EM,
535    /// BM25 索引
536    bm25_index: BM25Index,
537    /// 文档内容缓存
538    doc_cache: Arc<RwLock<HashMap<String, DocCacheEntry>>>,
539    /// 默认配置
540    default_config: HybridRetrieverConfig,
541}
542
543impl<VS, EM> DefaultHybridRetriever<VS, EM>
544where
545    VS: VectorStore,
546    EM: EmbeddingModel,
547{
548    /// 创建新的混合检索器
549    pub fn new(vector_store: VS, embedding_model: EM) -> Self {
550        Self {
551            vector_store,
552            embedding_model,
553            bm25_index: BM25Index::new(),
554            doc_cache: Arc::new(RwLock::new(HashMap::new())),
555            default_config: HybridRetrieverConfig::new(),
556        }
557    }
558
559    /// 使用自定义配置创建
560    pub fn with_config(
561        vector_store: VS,
562        embedding_model: EM,
563        config: HybridRetrieverConfig,
564    ) -> Self {
565        Self {
566            vector_store,
567            embedding_model,
568            bm25_index: BM25Index::new(),
569            doc_cache: Arc::new(RwLock::new(HashMap::new())),
570            default_config: config,
571        }
572    }
573
574    /// 向量检索
575    #[instrument(skip(self))]
576    async fn vector_search(&self, query: &str, top_k: usize) -> Layer3Result<Vec<(String, f64)>> {
577        let query_embedding = self.embedding_model.embed(query).await?;
578        let results = self.vector_store.query(query_embedding, top_k).await?;
579
580        Ok(results
581            .into_iter()
582            .map(|r| (r.doc_id, r.score as f64))
583            .collect())
584    }
585
586    /// BM25 检索
587    #[instrument(skip(self))]
588    fn bm25_search(&self, query: &str, top_k: usize) -> Vec<(String, f64)> {
589        self.bm25_index.search(query, top_k)
590    }
591
592    /// 获取文档内容
593    fn get_document_content(
594        &self,
595        doc_id: &str,
596    ) -> Option<(String, HashMap<String, serde_json::Value>)> {
597        self.doc_cache.read().get(doc_id).cloned()
598    }
599
600    /// 应用分数阈值过滤
601    #[allow(dead_code)]
602    fn apply_threshold(&self, results: Vec<(String, f64)>, threshold: f64) -> Vec<(String, f64)> {
603        results
604            .into_iter()
605            .filter(|(_, score)| *score >= threshold)
606            .collect()
607    }
608}
609
610#[async_trait]
611impl<VS, EM> HybridRetriever for DefaultHybridRetriever<VS, EM>
612where
613    VS: VectorStore,
614    EM: EmbeddingModel,
615{
616    #[instrument(skip(self, documents))]
617    async fn index_documents(&self, documents: Vec<Document>) -> Layer3Result<Vec<String>> {
618        use crate::vector_store::VectorItem;
619
620        let mut doc_ids = Vec::new();
621        let mut vector_items = Vec::new();
622        let mut bm25_docs = Vec::new();
623
624        for doc in documents {
625            let doc_id = doc.id.unwrap_or_else(generate_short_id);
626
627            // 缓存文档内容
628            {
629                let mut cache = self.doc_cache.write();
630                cache.insert(doc_id.clone(), (doc.content.clone(), doc.metadata.clone()));
631            }
632
633            // 添加到 BM25 索引
634            bm25_docs.push((doc_id.clone(), doc.content.clone()));
635
636            // 生成 embedding
637            let embedding = self.embedding_model.embed(&doc.content).await?;
638
639            let mut metadata = doc.metadata.clone();
640            if let Some(source) = doc.source {
641                metadata.insert("source".to_string(), serde_json::json!(source));
642            }
643
644            vector_items.push(VectorItem {
645                id: doc_id.clone(),
646                vector: embedding,
647                metadata,
648                content: Some(doc.content),
649            });
650
651            doc_ids.push(doc_id);
652        }
653
654        // 批量添加到 BM25 索引
655        self.bm25_index.add_documents(bm25_docs);
656
657        // 批量添加到向量存储
658        self.vector_store.add_batch(vector_items).await?;
659
660        Ok(doc_ids)
661    }
662
663    #[instrument(skip(self))]
664    async fn retrieve(
665        &self,
666        query: &str,
667        top_k: usize,
668        config: Option<&HybridRetrieverConfig>,
669    ) -> Layer3Result<Vec<RetrievalResult>> {
670        let config = config.unwrap_or(&self.default_config);
671
672        // 计算候选结果数量
673        let candidates = top_k * config.candidate_multiplier;
674
675        // 收集检索结果
676        let mut result_lists: Vec<Vec<(String, f64)>> = Vec::new();
677        let mut weights: Vec<f64> = Vec::new();
678
679        // 向量检索
680        if config.vector_weight > 0.0 {
681            let vector_results = self.vector_search(query, candidates).await?;
682            result_lists.push(vector_results);
683            weights.push(config.vector_weight);
684        }
685
686        // BM25 检索
687        if config.bm25_weight > 0.0 {
688            let bm25_results = self.bm25_search(query, candidates);
689            result_lists.push(bm25_results);
690            weights.push(config.bm25_weight);
691        }
692
693        // 如果只有一个检索器,直接返回
694        if result_lists.len() == 1 {
695            let results = result_lists.remove(0);
696            let final_results: Vec<RetrievalResult> = results
697                .into_iter()
698                .take(top_k)
699                .filter_map(|(doc_id, score)| {
700                    let (content, metadata) = self.get_document_content(&doc_id)?;
701                    let source = metadata
702                        .get("source")
703                        .and_then(|v| v.as_str())
704                        .map(String::from);
705                    Some(RetrievalResult {
706                        doc_id,
707                        content,
708                        score: score as f32,
709                        metadata,
710                        source,
711                    })
712                })
713                .collect();
714
715            return Ok(final_results);
716        }
717
718        // 融合结果
719        let fused_results = if config.use_rrf {
720            let rrf = ReciprocalRankFusion::new(config.rrf_k);
721            rrf.fuse_with_weights(&result_lists, &weights, top_k)
722        } else {
723            // 简单加权融合
724            let mut combined: HashMap<String, f64> = HashMap::new();
725            for (results, weight) in result_lists.iter().zip(weights.iter()) {
726                for (doc_id, score) in results {
727                    *combined.entry(doc_id.clone()).or_insert(0.0) += score * weight;
728                }
729            }
730            let mut fused: Vec<(String, f64)> = combined.into_iter().collect();
731            fused.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
732            fused.truncate(top_k);
733            fused
734        };
735
736        // 构建最终结果
737        let final_results: Vec<RetrievalResult> = fused_results
738            .into_iter()
739            .filter_map(|(doc_id, score)| {
740                let (content, metadata) = self.get_document_content(&doc_id)?;
741                let source = metadata
742                    .get("source")
743                    .and_then(|v| v.as_str())
744                    .map(String::from);
745                Some(RetrievalResult {
746                    doc_id,
747                    content,
748                    score: score as f32,
749                    metadata,
750                    source,
751                })
752            })
753            .collect();
754
755        Ok(final_results)
756    }
757
758    async fn retrieve_with_filter(
759        &self,
760        query: &str,
761        top_k: usize,
762        filter: Option<MetadataFilter>,
763        config: Option<&HybridRetrieverConfig>,
764    ) -> Layer3Result<Vec<RetrievalResult>> {
765        let config = config.unwrap_or(&self.default_config);
766        let candidates = top_k * config.candidate_multiplier * 2;
767
768        // 获取更多候选结果
769        let mut results = self.retrieve(query, candidates, Some(config)).await?;
770
771        // 应用过滤器
772        if let Some(f) = filter {
773            results.retain(|r| {
774                // 简单的元数据匹配检查
775                f.must
776                    .iter()
777                    .all(|(key, value)| r.metadata.get(key) == Some(value))
778            });
779        }
780
781        // 应用分数阈值
782        if config.min_score_threshold > 0.0 {
783            results.retain(|r| r.score >= config.min_score_threshold as f32);
784        }
785
786        results.truncate(top_k);
787        Ok(results)
788    }
789
790    async fn delete_documents(&self, doc_ids: &[String]) -> Layer3Result<bool> {
791        // 从向量存储删除
792        self.vector_store.delete_batch(doc_ids).await?;
793
794        // 从 BM25 索引删除
795        for doc_id in doc_ids {
796            self.bm25_index.remove_document(doc_id);
797        }
798
799        // 从缓存删除
800        {
801            let mut cache = self.doc_cache.write();
802            for doc_id in doc_ids {
803                cache.remove(doc_id);
804            }
805        }
806
807        Ok(true)
808    }
809
810    async fn clear(&self) -> Layer3Result<bool> {
811        self.vector_store.clear().await?;
812        self.bm25_index.clear();
813        self.doc_cache.write().clear();
814        Ok(true)
815    }
816
817    async fn count(&self) -> Layer3Result<usize> {
818        Ok(self.bm25_index.doc_count())
819    }
820}
821
822// ============================================================================
823// Tests
824// ============================================================================
825
#[cfg(test)]
mod tests {
    use super::*;
    use crate::retriever_engine::Layer1EmbeddingAdapter;
    use crate::vector_store::InMemoryVectorStore;

    /// Builds a mock embedding model for tests by wrapping Layer1's
    /// `MockEmbeddingModel` in the adapter.
    fn create_mock_embedding_model(dimension: usize) -> Layer1EmbeddingAdapter {
        Layer1EmbeddingAdapter::new(Box::new(sh_layer1::MockEmbeddingModel::new(dimension)))
    }

    /// Builds a retriever over an in-memory vector store and a 128-dimension mock model.
    fn new_test_retriever() -> impl HybridRetriever {
        let vector_store = InMemoryVectorStore::in_memory();
        let embedding_model = create_mock_embedding_model(128);
        DefaultHybridRetriever::new(vector_store, embedding_model)
    }

    /// Builds an index pre-populated with the given `(id, content)` pairs.
    fn index_of(docs: &[(&str, &str)]) -> BM25Index {
        let index = BM25Index::new();
        for (id, content) in docs {
            index.add_document(id.to_string(), content);
        }
        index
    }

    /// Turns `(id, score)` literals into an owned result list.
    fn ranked(entries: &[(&str, f64)]) -> Vec<(String, f64)> {
        entries
            .iter()
            .map(|(id, score)| (id.to_string(), *score))
            .collect()
    }

    #[test]
    fn test_bm25_index_basic() {
        let index = index_of(&[
            ("doc1", "Rust is a systems programming language"),
            ("doc2", "Python is used for data science"),
            ("doc3", "JavaScript runs in the browser"),
        ]);

        let results = index.search("Rust programming", 5);
        assert!(!results.is_empty());
        assert_eq!(results[0].0, "doc1");
    }

    #[test]
    fn test_bm25_index_scoring() {
        let index = index_of(&[
            ("doc1", "machine learning algorithms"),
            ("doc2", "deep learning neural networks"),
            ("doc3", "database systems"),
        ]);

        let results = index.search("machine learning", 3);
        assert!(!results.is_empty());

        // doc1 and doc2 both contain "learning", but only doc1 contains "machine".
        assert!(results.iter().any(|(id, _)| id == "doc1"));
    }

    #[test]
    fn test_bm25_remove_document() {
        let index = index_of(&[("doc1", "test document")]);
        assert_eq!(index.doc_count(), 1);

        assert!(index.remove_document("doc1"));
        assert_eq!(index.doc_count(), 0);

        assert!(!index.remove_document("nonexistent"));
    }

    #[test]
    fn test_rrf_fusion() {
        let rrf = ReciprocalRankFusion::default_fusion();

        let list1 = ranked(&[("doc1", 0.9), ("doc2", 0.8), ("doc3", 0.7)]);
        let list2 = ranked(&[("doc3", 0.95), ("doc1", 0.85), ("doc4", 0.75)]);

        let fused = rrf.fuse(&[list1, list2], 5);

        assert!(!fused.is_empty());
        // doc1 and doc3 appear in both lists, so they should rank near the top.
        assert!(fused
            .iter()
            .take(2)
            .any(|(id, _)| id == "doc1" || id == "doc3"));
    }

    #[test]
    fn test_rrf_with_weights() {
        let rrf = ReciprocalRankFusion::new(60.0);

        let list1 = ranked(&[("doc1", 0.9)]);
        let list2 = ranked(&[("doc2", 0.9)]);

        let fused = rrf.fuse_with_weights(&[list1, list2], &[0.7, 0.3], 5);
        assert!(!fused.is_empty());
    }

    #[test]
    fn test_hybrid_retriever_config() {
        let config = HybridRetrieverConfig::new();
        assert_eq!(config.vector_weight, 0.7);
        assert_eq!(config.bm25_weight, 0.3);
        assert!(config.use_rrf);

        let vector_only = HybridRetrieverConfig::vector_only();
        assert_eq!(vector_only.vector_weight, 1.0);
        assert_eq!(vector_only.bm25_weight, 0.0);

        let balanced = HybridRetrieverConfig::balanced();
        assert_eq!(balanced.vector_weight, 0.5);
        assert_eq!(balanced.bm25_weight, 0.5);

        let custom = HybridRetrieverConfig::new().with_weights(0.8, 0.2);
        assert!((custom.vector_weight - 0.8).abs() < 0.001);
        assert!((custom.bm25_weight - 0.2).abs() < 0.001);
    }

    #[tokio::test]
    async fn test_hybrid_retriever_index_and_search() {
        let retriever = new_test_retriever();

        let docs = vec![
            Document::new("Rust is a systems programming language"),
            Document::new("Python is widely used for data science"),
            Document::new("JavaScript runs in the browser"),
        ];

        let doc_ids = retriever.index_documents(docs).await.unwrap();
        assert_eq!(doc_ids.len(), 3);

        let results = retriever
            .retrieve("Rust programming", 5, None)
            .await
            .unwrap();
        assert!(!results.is_empty());
    }

    #[tokio::test]
    async fn test_hybrid_retriever_with_config() {
        let retriever = new_test_retriever();

        retriever
            .index_documents(vec![
                Document::new("Machine learning algorithms use neural networks"),
                Document::new("Database stores data for applications"),
            ])
            .await
            .unwrap();

        // Vector retrieval only.
        let vector_only = HybridRetrieverConfig::vector_only();
        let results = retriever
            .retrieve("neural networks", 5, Some(&vector_only))
            .await
            .unwrap();
        assert!(!results.is_empty());

        // BM25 retrieval only.
        let bm25_only = HybridRetrieverConfig::bm25_only();
        let results = retriever
            .retrieve("machine learning", 5, Some(&bm25_only))
            .await
            .unwrap();
        assert!(!results.is_empty());

        // Balanced retrieval with RRF.
        let balanced = HybridRetrieverConfig::balanced().with_rrf(true, 60.0);
        let results = retriever
            .retrieve("database", 5, Some(&balanced))
            .await
            .unwrap();
        assert!(!results.is_empty());
    }

    #[tokio::test]
    async fn test_hybrid_retriever_delete_and_count() {
        let retriever = new_test_retriever();

        let doc_ids = retriever
            .index_documents(vec![Document::new("Test document")])
            .await
            .unwrap();

        assert_eq!(retriever.count().await.unwrap(), 1);

        retriever.delete_documents(&doc_ids).await.unwrap();
        assert_eq!(retriever.count().await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_hybrid_retriever_clear() {
        let retriever = new_test_retriever();

        retriever
            .index_documents(vec![Document::new("Doc 1"), Document::new("Doc 2")])
            .await
            .unwrap();

        assert_eq!(retriever.count().await.unwrap(), 2);

        retriever.clear().await.unwrap();
        assert_eq!(retriever.count().await.unwrap(), 0);
    }
}
1036}