Skip to main content

xore_ai/
search.rs

1//! 向量搜索引擎
2//!
3//! 基于嵌入向量的语义搜索
4
5use crate::EmbeddingModel;
6use anyhow::Result;
7use serde::{Deserialize, Serialize};
8use std::path::PathBuf;
9use tracing::{debug, info};
10
11/// 文档
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct Document {
14    /// 文档 ID
15    pub id: String,
16    /// 文件路径
17    pub path: PathBuf,
18    /// 文档内容
19    pub content: String,
20}
21
22/// 搜索结果
23#[derive(Debug, Clone)]
24pub struct SearchResult {
25    /// 文档
26    pub document: Document,
27    /// 相似度分数 [0, 1]
28    pub score: f32,
29}
30
31/// 向量搜索引擎
32pub struct VectorSearcher {
33    /// 嵌入模型
34    model: EmbeddingModel,
35    /// 文档集合
36    documents: Vec<Document>,
37    /// 预计算的嵌入向量
38    index: Vec<Vec<f32>>,
39}
40
41impl VectorSearcher {
42    /// 创建新的向量搜索引擎
43    ///
44    /// # 参数
45    /// - `model`: 嵌入模型
46    ///
47    /// # 返回
48    /// 向量搜索引擎实例
49    pub fn new(model: EmbeddingModel) -> Self {
50        Self { model, documents: Vec::new(), index: Vec::new() }
51    }
52
53    /// 添加文档到索引
54    ///
55    /// # 参数
56    /// - `doc`: 文档
57    ///
58    /// # 返回
59    /// 成功或错误
60    ///
61    /// # 注意
62    /// 需要 &mut self 因为 encode() 需要可变引用
63    pub fn add_document(&mut self, doc: Document) -> Result<()> {
64        info!("添加文档到索引: {:?}", doc.path);
65
66        // 生成嵌入向量(需要 &mut self.model)
67        let embedding = self.model.encode(&doc.content)?;
68
69        self.documents.push(doc);
70        self.index.push(embedding);
71
72        debug!("当前索引文档数: {}", self.documents.len());
73        Ok(())
74    }
75
76    /// 批量添加文档
77    ///
78    /// # 参数
79    /// - `docs`: 文档列表
80    ///
81    /// # 返回
82    /// 成功添加的文档数量
83    pub fn add_documents(&mut self, docs: Vec<Document>) -> Result<usize> {
84        let total = docs.len();
85        info!("批量添加 {} 个文档到索引", total);
86
87        let mut success_count = 0;
88        for doc in docs {
89            if self.add_document(doc).is_ok() {
90                success_count += 1;
91            }
92        }
93
94        info!("成功添加 {}/{} 个文档", success_count, total);
95        Ok(success_count)
96    }
97
98    /// 语义搜索
99    ///
100    /// # 参数
101    /// - `query`: 查询文本
102    /// - `top_k`: 返回结果数量
103    ///
104    /// # 返回
105    /// 搜索结果列表(按相似度降序)
106    pub fn search(&mut self, query: &str, top_k: usize) -> Result<Vec<SearchResult>> {
107        if self.documents.is_empty() {
108            return Ok(Vec::new());
109        }
110
111        debug!("语义搜索: \"{}\"", query);
112
113        // 1. 查询向量化
114        let query_embedding = self.model.encode(query)?;
115
116        // 2. 计算余弦相似度
117        let mut scores: Vec<(usize, f32)> = Vec::new();
118        for (i, doc_embedding) in self.index.iter().enumerate() {
119            let score = EmbeddingModel::cosine_similarity(&query_embedding, doc_embedding);
120            scores.push((i, score));
121        }
122
123        // 3. 排序取 top_k
124        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
125
126        // 4. 返回结果
127        let results: Vec<SearchResult> = scores
128            .into_iter()
129            .take(top_k)
130            .map(|(i, score)| SearchResult { document: self.documents[i].clone(), score })
131            .collect();
132
133        debug!("找到 {} 个结果", results.len());
134        Ok(results)
135    }
136
137    /// 获取索引中的文档数量
138    pub fn document_count(&self) -> usize {
139        self.documents.len()
140    }
141
142    /// 清空索引
143    pub fn clear(&mut self) {
144        self.documents.clear();
145        self.index.clear();
146    }
147}
148
149/// 计算余弦相似度(独立函数)
150///
151/// # 参数
152/// - `a`: 向量 A
153/// - `b`: 向量 B
154///
155/// # 返回
156/// 余弦相似度 [-1, 1]
157pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
158    if a.len() != b.len() || a.is_empty() {
159        return 0.0;
160    }
161
162    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
163    let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
164    let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
165
166    if norm_a > 0.0 && norm_b > 0.0 {
167        dot / (norm_a * norm_b)
168    } else {
169        0.0
170    }
171}
172
173#[cfg(test)]
174mod tests {
175    use super::*;
176
177    #[test]
178    fn test_cosine_similarity() {
179        let a = vec![1.0, 0.0, 0.0];
180        let b = vec![1.0, 0.0, 0.0];
181        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 1e-6);
182
183        let c = vec![1.0, 0.0, 0.0];
184        let d = vec![0.0, 1.0, 0.0];
185        assert!((cosine_similarity(&c, &d) - 0.0).abs() < 1e-6);
186
187        let e = vec![1.0, 1.0];
188        let f = vec![-1.0, -1.0];
189        assert!((cosine_similarity(&e, &f) + 1.0).abs() < 1e-6);
190    }
191
192    #[test]
193    fn test_cosine_similarity_edge_cases() {
194        // 空向量
195        let empty: Vec<f32> = vec![];
196        assert_eq!(cosine_similarity(&empty, &empty), 0.0);
197
198        // 长度不匹配
199        let a = vec![1.0, 2.0];
200        let b = vec![1.0, 2.0, 3.0];
201        assert_eq!(cosine_similarity(&a, &b), 0.0);
202
203        // 零向量
204        let zero = vec![0.0, 0.0, 0.0];
205        let nonzero = vec![1.0, 2.0, 3.0];
206        assert_eq!(cosine_similarity(&zero, &nonzero), 0.0);
207    }
208}