1use crate::EmbeddingModel;
6use anyhow::Result;
7use serde::{Deserialize, Serialize};
8use std::path::PathBuf;
9use tracing::{debug, info};
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct Document {
14 pub id: String,
16 pub path: PathBuf,
18 pub content: String,
20}
21
22#[derive(Debug, Clone)]
24pub struct SearchResult {
25 pub document: Document,
27 pub score: f32,
29}
30
31pub struct VectorSearcher {
33 model: EmbeddingModel,
35 documents: Vec<Document>,
37 index: Vec<Vec<f32>>,
39}
40
41impl VectorSearcher {
42 pub fn new(model: EmbeddingModel) -> Self {
50 Self { model, documents: Vec::new(), index: Vec::new() }
51 }
52
53 pub fn add_document(&mut self, doc: Document) -> Result<()> {
64 info!("添加文档到索引: {:?}", doc.path);
65
66 let embedding = self.model.encode(&doc.content)?;
68
69 self.documents.push(doc);
70 self.index.push(embedding);
71
72 debug!("当前索引文档数: {}", self.documents.len());
73 Ok(())
74 }
75
76 pub fn add_documents(&mut self, docs: Vec<Document>) -> Result<usize> {
84 let total = docs.len();
85 info!("批量添加 {} 个文档到索引", total);
86
87 let mut success_count = 0;
88 for doc in docs {
89 if self.add_document(doc).is_ok() {
90 success_count += 1;
91 }
92 }
93
94 info!("成功添加 {}/{} 个文档", success_count, total);
95 Ok(success_count)
96 }
97
98 pub fn search(&mut self, query: &str, top_k: usize) -> Result<Vec<SearchResult>> {
107 if self.documents.is_empty() {
108 return Ok(Vec::new());
109 }
110
111 debug!("语义搜索: \"{}\"", query);
112
113 let query_embedding = self.model.encode(query)?;
115
116 let mut scores: Vec<(usize, f32)> = Vec::new();
118 for (i, doc_embedding) in self.index.iter().enumerate() {
119 let score = EmbeddingModel::cosine_similarity(&query_embedding, doc_embedding);
120 scores.push((i, score));
121 }
122
123 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
125
126 let results: Vec<SearchResult> = scores
128 .into_iter()
129 .take(top_k)
130 .map(|(i, score)| SearchResult { document: self.documents[i].clone(), score })
131 .collect();
132
133 debug!("找到 {} 个结果", results.len());
134 Ok(results)
135 }
136
137 pub fn document_count(&self) -> usize {
139 self.documents.len()
140 }
141
142 pub fn clear(&mut self) {
144 self.documents.clear();
145 self.index.clear();
146 }
147}
148
/// Computes the cosine similarity of two vectors.
///
/// Returns a value in `[-1.0, 1.0]`. Returns `0.0` when the similarity is undefined:
/// the slices differ in length, are empty, either has zero norm, or an input is
/// NaN or infinite.
///
/// Sums are accumulated in `f64` so that squaring very large or very small `f32`
/// components cannot overflow to infinity or underflow to zero.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    // Single pass over both vectors: dot product and both squared norms.
    let (dot, sq_a, sq_b) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, sq_a, sq_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, sq_a + x * x, sq_b + y * y)
        },
    );

    let denom = sq_a.sqrt() * sq_b.sqrt();

    // A NaN or infinite input makes `denom` non-finite; a zero vector makes it zero.
    if denom.is_finite() && denom > 0.0 {
        // Clamp away rounding noise; the narrowing cast is safe because the value is in [-1, 1].
        (dot / denom).clamp(-1.0, 1.0) as f32
    } else {
        0.0
    }
}
172
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing similarities to an expected value.
    const EPSILON: f32 = 1e-6;

    #[test]
    fn test_cosine_similarity() {
        // Same direction: similarity is 1.
        let x_axis = [1.0_f32, 0.0, 0.0];
        let same = cosine_similarity(&x_axis, &x_axis);
        assert!((same - 1.0).abs() < EPSILON);

        // Orthogonal vectors: similarity is 0.
        let y_axis = [0.0_f32, 1.0, 0.0];
        let orthogonal = cosine_similarity(&x_axis, &y_axis);
        assert!((orthogonal - 0.0).abs() < EPSILON);

        // Opposite direction: similarity is -1.
        let forward = [1.0_f32, 1.0];
        let backward = [-1.0_f32, -1.0];
        let opposite = cosine_similarity(&forward, &backward);
        assert!((opposite + 1.0).abs() < EPSILON);
    }

    #[test]
    fn test_cosine_similarity_edge_cases() {
        // Empty input.
        let nothing: [f32; 0] = [];
        assert_eq!(cosine_similarity(&nothing, &nothing), 0.0);

        // Mismatched lengths.
        let short = [1.0_f32, 2.0];
        let long = [1.0_f32, 2.0, 3.0];
        assert_eq!(cosine_similarity(&short, &long), 0.0);

        // Zero vector against a non-zero vector.
        let origin = [0.0_f32, 0.0, 0.0];
        assert_eq!(cosine_similarity(&origin, &long), 0.0);
    }
}