1use anyhow::Result;
2use hnsw::Hnsw;
3use rand_pcg::Pcg64;
4use reqwest::Client;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::path::PathBuf;
8use std::sync::{Arc, RwLock};
9
/// Base URL of the Hugging Face hosted inference router. The model id and the
/// pipeline name are appended to it for every embedding request.
const HUGGINGFACE_API_URL: &str = "https://router.huggingface.co/hf-inference/models";

/// Embedding model requested from the inference API.
const DEFAULT_MODEL: &str = "sentence-transformers/all-MiniLM-L6-v2";

/// Embedding length used when `VECTOR_DIMENSIONS` is not set; it should match
/// the output size of `DEFAULT_MODEL`.
const DEFAULT_DIMENSIONS: usize = 384;
13
14#[derive(Default, Clone)]
16pub struct Euclidian;
17
18impl space::Metric<Vec<f32>> for Euclidian {
19 type Unit = u32;
20 fn distance(&self, a: &Vec<f32>, b: &Vec<f32>) -> u32 {
21 let len = a.len().min(b.len());
22 let mut dist_sq = 0.0;
23 for i in 0..len {
24 let diff = a[i] - b[i];
25 dist_sq += diff * diff;
26 }
27 (dist_sq.sqrt() * 1_000_000.0) as u32
29 }
30}
31
32#[derive(Serialize, Deserialize, Default)]
34struct VectorData {
35 entries: Vec<VectorEntry>,
36}
37
38#[derive(Serialize, Deserialize, Clone)]
39struct VectorEntry {
40 uri: String,
41 embedding: Vec<f32>,
42}
43
44pub struct VectorStore {
46 index: Arc<RwLock<Hnsw<Euclidian, Vec<f32>, Pcg64, 16, 32>>>,
48 id_to_uri: Arc<RwLock<HashMap<usize, String>>>,
50 uri_to_id: Arc<RwLock<HashMap<String, usize>>>,
52 storage_path: Option<PathBuf>,
54 client: Client,
56 api_token: Option<String>,
58 model: String,
60 dimensions: usize,
62 embeddings: Arc<RwLock<Vec<VectorEntry>>>,
64}
65
66#[derive(Debug, Serialize, Deserialize)]
67pub struct SearchResult {
68 pub uri: String,
69 pub score: f32,
70 pub content: String,
71}
72
73impl VectorStore {
74 pub fn new(namespace: &str) -> Result<Self> {
76 let storage_path = std::env::var("GRAPH_STORAGE_PATH")
78 .ok()
79 .map(|p| PathBuf::from(p).join(namespace));
80
81 let dimensions = std::env::var("VECTOR_DIMENSIONS")
83 .ok()
84 .and_then(|s| s.parse().ok())
85 .unwrap_or(DEFAULT_DIMENSIONS);
86
87 let mut index = Hnsw::new(Euclidian);
89 let mut id_to_uri = HashMap::new();
90 let mut uri_to_id = HashMap::new();
91 let mut embeddings = Vec::new();
92
93 if let Some(ref path) = storage_path {
95 let vectors_path = path.join("vectors.json");
96 if vectors_path.exists() {
97 if let Ok(content) = std::fs::read_to_string(&vectors_path) {
98 if let Ok(data) = serde_json::from_str::<VectorData>(&content) {
99 let mut searcher = hnsw::Searcher::default();
100 for entry in data.entries {
101 if entry.embedding.len() == dimensions {
102 let id = index.insert(entry.embedding.clone(), &mut searcher);
103 id_to_uri.insert(id, entry.uri.clone());
104 uri_to_id.insert(entry.uri.clone(), id);
105 embeddings.push(entry);
106 }
107 }
108 eprintln!("Loaded {} vectors from disk (dim={})", embeddings.len(), dimensions);
109 }
110 }
111 }
112 }
113
114 let api_token = std::env::var("HUGGINGFACE_API_TOKEN").ok();
116
117 let client = Client::builder()
119 .timeout(std::time::Duration::from_secs(30))
120 .build()
121 .unwrap_or_else(|_| Client::new());
122
123 Ok(Self {
124 index: Arc::new(RwLock::new(index)),
125 id_to_uri: Arc::new(RwLock::new(id_to_uri)),
126 uri_to_id: Arc::new(RwLock::new(uri_to_id)),
127 storage_path,
128 client,
129 api_token,
130 model: DEFAULT_MODEL.to_string(),
131 dimensions,
132 embeddings: Arc::new(RwLock::new(embeddings)),
133 })
134 }
135
136 fn save_vectors(&self) -> Result<()> {
138 if let Some(ref path) = self.storage_path {
139 std::fs::create_dir_all(path)?;
140 let data = VectorData {
141 entries: self.embeddings.read().unwrap().clone(),
142 };
143 let content = serde_json::to_string(&data)?;
144 std::fs::write(path.join("vectors.json"), content)?;
145 }
146 Ok(())
147 }
148
149 pub async fn embed(&self, text: &str) -> Result<Vec<f32>> {
151 let embeddings = self.embed_batch(vec![text.to_string()]).await?;
152 Ok(embeddings[0].clone())
153 }
154
155 pub async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
157 if texts.is_empty() {
158 return Ok(Vec::new());
159 }
160
161 let url = format!(
162 "{}/{}/pipeline/feature-extraction",
163 HUGGINGFACE_API_URL, self.model
164 );
165
166 let mut request = self.client.post(&url).json(&serde_json::json!({
168 "inputs": texts,
169 }));
170
171 if let Some(ref token) = self.api_token {
173 request = request.header("Authorization", format!("Bearer {}", token));
174 }
175
176 let response = request.send().await?;
177
178 if !response.status().is_success() {
179 let error_text = response.text().await?;
180 anyhow::bail!("HuggingFace API error: {}", error_text);
181 }
182
183 let response_json: serde_json::Value = response.json().await?;
184 let mut results = Vec::new();
185
186 if let Some(arr) = response_json.as_array() {
187 for item in arr {
188 let vec: Vec<f32> = serde_json::from_value(item.clone())
189 .map_err(|e| anyhow::anyhow!("Failed to parse embedding: {}", e))?;
190
191 if vec.len() != self.dimensions {
192 anyhow::bail!("Expected {} dimensions, got {}", self.dimensions, vec.len());
193 }
194 results.push(vec);
195 }
196 } else {
197 if texts.len() == 1 {
199 if let Ok(vec) = serde_json::from_value::<Vec<f32>>(response_json) {
200 if vec.len() == self.dimensions {
201 results.push(vec);
202 }
203 }
204 }
205 }
206
207 if results.len() != texts.len() {
208 anyhow::bail!("Expected {} embeddings, got {}", texts.len(), results.len());
209 }
210
211 Ok(results)
212 }
213
214 pub async fn add(&self, uri: &str, content: &str) -> Result<usize> {
216 let results = self.add_batch(vec![(uri.to_string(), content.to_string())]).await?;
217 Ok(results[0])
218 }
219
220 pub async fn add_batch(&self, items: Vec<(String, String)>) -> Result<Vec<usize>> {
222 let mut new_items = Vec::new();
224 let mut result_ids = vec![0; items.len()];
225 let mut new_indices = Vec::new(); {
228 let uri_map = self.uri_to_id.read().unwrap();
229 for (i, (uri, content)) in items.iter().enumerate() {
230 if let Some(&id) = uri_map.get(uri) {
231 result_ids[i] = id;
232 } else {
233 new_items.push(content.clone());
234 new_indices.push(i);
235 }
236 }
237 }
238
239 if new_items.is_empty() {
240 return Ok(result_ids);
241 }
242
243 let embeddings = self.embed_batch(new_items).await?;
245
246 let mut ids_to_add = Vec::new();
247 let mut searcher = hnsw::Searcher::default();
248
249 {
251 let mut index = self.index.write().unwrap();
252 let mut uri_map = self.uri_to_id.write().unwrap();
253 let mut id_map = self.id_to_uri.write().unwrap();
254 let mut embs = self.embeddings.write().unwrap();
255
256 for (i, embedding) in embeddings.into_iter().enumerate() {
257 let original_idx = new_indices[i];
258 let uri = &items[original_idx].0;
259
260 if let Some(&id) = uri_map.get(uri) {
262 result_ids[original_idx] = id;
263 continue;
264 }
265
266 let id = index.insert(embedding.clone(), &mut searcher);
267 uri_map.insert(uri.clone(), id);
268 id_map.insert(id, uri.clone());
269
270 embs.push(VectorEntry {
271 uri: uri.clone(),
272 embedding: embedding,
273 });
274
275 result_ids[original_idx] = id;
276 ids_to_add.push(id);
277 }
278 }
279
280 if !ids_to_add.is_empty() {
281 let _ = self.save_vectors(); }
283
284 Ok(result_ids)
285 }
286
287 pub async fn search(&self, query: &str, k: usize) -> Result<Vec<SearchResult>> {
289 let query_embedding = self.embed(query).await?;
291
292 let mut searcher = hnsw::Searcher::default();
294 let mut neighbors = vec![
295 space::Neighbor {
296 index: 0,
297 distance: u32::MAX
298 };
299 k
300 ];
301
302 let found_neighbors = {
303 let index = self.index.read().unwrap();
304 index.nearest(&query_embedding, 50, &mut searcher, &mut neighbors)
305 };
306
307 let id_map = self.id_to_uri.read().unwrap();
309 let results: Vec<SearchResult> = found_neighbors
310 .iter()
311 .filter_map(|neighbor| {
312 id_map.get(&neighbor.index).map(|uri| {
313 let score_f32 = (neighbor.distance as f32) / 1_000_000.0;
314 SearchResult {
315 uri: uri.clone(),
316 score: 1.0 / (1.0 + score_f32),
317 content: uri.clone(),
318 }
319 })
320 })
321 .collect();
322
323 Ok(results)
324 }
325
326 pub fn get_uri(&self, id: usize) -> Option<String> {
327 self.id_to_uri.read().unwrap().get(&id).cloned()
328 }
329
330 pub fn get_id(&self, uri: &str) -> Option<usize> {
331 self.uri_to_id.read().unwrap().get(uri).copied()
332 }
333
334 pub fn len(&self) -> usize {
335 self.uri_to_id.read().unwrap().len()
336 }
337
338 pub fn is_empty(&self) -> bool {
339 self.len() == 0
340 }
341
342 pub fn compact(&self) -> Result<usize> {
344 let embeddings = self.embeddings.read().unwrap();
345 let current_uris: std::collections::HashSet<_> = self.uri_to_id.read().unwrap().keys().cloned().collect();
346
347 if current_uris.is_empty() && !embeddings.is_empty() {
350 return Ok(0);
352 }
353
354 let active_entries: Vec<_> = embeddings
356 .iter()
357 .filter(|e| current_uris.contains(&e.uri))
358 .cloned()
359 .collect();
360
361 let removed = embeddings.len() - active_entries.len();
362
363 if removed == 0 {
364 return Ok(0);
365 }
366
367 let mut new_index = hnsw::Hnsw::new(Euclidian);
369 let mut new_id_to_uri = std::collections::HashMap::new();
370 let mut new_uri_to_id = std::collections::HashMap::new();
371 let mut searcher = hnsw::Searcher::default();
372
373 for entry in &active_entries {
374 if entry.embedding.len() == self.dimensions {
375 let id = new_index.insert(entry.embedding.clone(), &mut searcher);
376 new_id_to_uri.insert(id, entry.uri.clone());
377 new_uri_to_id.insert(entry.uri.clone(), id);
378 }
379 }
380
381 *self.index.write().unwrap() = new_index;
383 *self.id_to_uri.write().unwrap() = new_id_to_uri;
384 *self.uri_to_id.write().unwrap() = new_uri_to_id;
385
386 drop(embeddings);
388 *self.embeddings.write().unwrap() = active_entries;
389
390 let _ = self.save_vectors();
391
392 Ok(removed)
393 }
394
395 pub fn remove(&self, uri: &str) -> bool {
397 let mut uri_map = self.uri_to_id.write().unwrap();
398 let mut id_map = self.id_to_uri.write().unwrap();
399
400 if let Some(id) = uri_map.remove(uri) {
401 id_map.remove(&id);
402 true
404 } else {
405 false
406 }
407 }
408
409 pub fn stats(&self) -> (usize, usize, usize) {
411 let embeddings_count = self.embeddings.read().unwrap().len();
412 let active_count = self.uri_to_id.read().unwrap().len();
413 let stale_count = embeddings_count.saturating_sub(active_count);
414 (active_count, stale_count, embeddings_count)
415 }
416}