1use anyhow::Result;
2use hnsw::Hnsw;
3use rand_pcg::Pcg64;
4use reqwest::Client;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::path::PathBuf;
8use std::sync::{Arc, RwLock};
9
/// Base URL of the HuggingFace inference router; the model id and pipeline
/// name are appended per request.
const HUGGINGFACE_API_URL: &str = "https://router.huggingface.co/hf-inference/models";

/// Model used for feature extraction when building embeddings.
const DEFAULT_MODEL: &str = "sentence-transformers/all-MiniLM-L6-v2";

/// Expected embedding length unless the `VECTOR_DIMENSIONS` environment
/// variable overrides it.
const DEFAULT_DIMENSIONS: usize = 384;
13
14#[derive(Default, Clone)]
16pub struct Euclidian;
17
18impl space::Metric<Vec<f32>> for Euclidian {
19 type Unit = u32;
20 fn distance(&self, a: &Vec<f32>, b: &Vec<f32>) -> u32 {
21 let len = a.len().min(b.len());
22 let mut dist_sq = 0.0;
23 for i in 0..len {
24 let diff = a[i] - b[i];
25 dist_sq += diff * diff;
26 }
27 (dist_sq.sqrt() * 1_000_000.0) as u32
29 }
30}
31
32#[derive(Serialize, Deserialize, Default)]
34struct VectorData {
35 entries: Vec<VectorEntry>,
36}
37
38#[derive(Serialize, Deserialize, Clone)]
39struct VectorEntry {
40 key: String,
42 embedding: Vec<f32>,
43 #[serde(default)]
45 metadata: serde_json::Value,
46}
47
48pub struct VectorStore {
50 index: Arc<RwLock<Hnsw<Euclidian, Vec<f32>, Pcg64, 16, 32>>>,
52 id_to_key: Arc<RwLock<HashMap<usize, String>>>,
54 key_to_id: Arc<RwLock<HashMap<String, usize>>>,
56 key_to_metadata: Arc<RwLock<HashMap<String, serde_json::Value>>>,
58 storage_path: Option<PathBuf>,
60 client: Client,
62 api_token: Option<String>,
64 model: String,
66 dimensions: usize,
68 embeddings: Arc<RwLock<Vec<VectorEntry>>>,
70}
71
72#[derive(Debug, Serialize, Deserialize)]
73pub struct SearchResult {
74 pub key: String,
76 pub score: f32,
77 pub metadata: serde_json::Value,
79 pub uri: String,
81}
82
83impl VectorStore {
84 pub fn new(namespace: &str) -> Result<Self> {
86 let storage_path = std::env::var("GRAPH_STORAGE_PATH")
88 .ok()
89 .map(|p| PathBuf::from(p).join(namespace));
90
91 let dimensions = std::env::var("VECTOR_DIMENSIONS")
93 .ok()
94 .and_then(|s| s.parse().ok())
95 .unwrap_or(DEFAULT_DIMENSIONS);
96
97 let mut index = Hnsw::new(Euclidian);
99 let mut id_to_key = HashMap::new();
100 let mut key_to_id = HashMap::new();
101 let mut key_to_metadata = HashMap::new();
102 let mut embeddings = Vec::new();
103
104 if let Some(ref path) = storage_path {
106 let vectors_path = path.join("vectors.json");
107 if vectors_path.exists() {
108 if let Ok(content) = std::fs::read_to_string(&vectors_path) {
109 if let Ok(data) = serde_json::from_str::<VectorData>(&content) {
111 let mut searcher = hnsw::Searcher::default();
112 for entry in data.entries {
113 if entry.embedding.len() == dimensions {
114 let id = index.insert(entry.embedding.clone(), &mut searcher);
115 id_to_key.insert(id, entry.key.clone());
116 key_to_id.insert(entry.key.clone(), id);
117 key_to_metadata.insert(entry.key.clone(), entry.metadata.clone());
118 embeddings.push(entry);
119 }
120 }
121 eprintln!("Loaded {} vectors from disk (dim={})", embeddings.len(), dimensions);
122 } else {
123 #[derive(Serialize, Deserialize)]
125 struct OldVectorData { entries: Vec<OldVectorEntry> }
126 #[derive(Serialize, Deserialize)]
127 struct OldVectorEntry { uri: String, embedding: Vec<f32> }
128
129 if let Ok(old_data) = serde_json::from_str::<OldVectorData>(&content) {
130 let mut searcher = hnsw::Searcher::default();
131 for old in old_data.entries {
132 if old.embedding.len() == dimensions {
133 let id = index.insert(old.embedding.clone(), &mut searcher);
134 id_to_key.insert(id, old.uri.clone());
135 key_to_id.insert(old.uri.clone(), id);
136 let metadata = serde_json::json!({ "uri": old.uri });
137 key_to_metadata.insert(old.uri.clone(), metadata.clone());
138 embeddings.push(VectorEntry {
139 key: old.uri.clone(),
140 embedding: old.embedding,
141 metadata,
142 });
143 }
144 }
145 eprintln!("Loaded {} legacy vectors from disk (dim={})", embeddings.len(), dimensions);
146 }
147 }
148 }
149 }
150 }
151
152 let api_token = std::env::var("HUGGINGFACE_API_TOKEN").ok();
154
155 let client = Client::builder()
157 .timeout(std::time::Duration::from_secs(30))
158 .build()
159 .unwrap_or_else(|_| Client::new());
160
161 Ok(Self {
162 index: Arc::new(RwLock::new(index)),
163 id_to_key: Arc::new(RwLock::new(id_to_key)),
164 key_to_id: Arc::new(RwLock::new(key_to_id)),
165 key_to_metadata: Arc::new(RwLock::new(key_to_metadata)),
166 storage_path,
167 client,
168 api_token,
169 model: DEFAULT_MODEL.to_string(),
170 dimensions,
171 embeddings: Arc::new(RwLock::new(embeddings)),
172 })
173 }
174
175 fn save_vectors(&self) -> Result<()> {
177 if let Some(ref path) = self.storage_path {
178 std::fs::create_dir_all(path)?;
179 let data = VectorData {
180 entries: self.embeddings.read().unwrap().clone(),
181 };
182 let content = serde_json::to_string(&data)?;
183 std::fs::write(path.join("vectors.json"), content)?;
184 }
185 Ok(())
186 }
187
188 pub async fn embed(&self, text: &str) -> Result<Vec<f32>> {
191 if std::env::var("MOCK_EMBEDDINGS").is_ok() {
192 use rand::Rng;
194 let mut rng = rand::rng();
195 let vec: Vec<f32> = (0..self.dimensions).map(|_| rng.random()).collect();
196 return Ok(vec);
197 }
198
199 let embeddings = self.embed_batch(vec![text.to_string()]).await?;
200 Ok(embeddings[0].clone())
201 }
202
203 pub async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
205 if texts.is_empty() {
206 return Ok(Vec::new());
207 }
208
209 if std::env::var("MOCK_EMBEDDINGS").is_ok() {
210 use rand::Rng;
211 let mut rng = rand::rng();
212 let mut results = Vec::new();
213 for _ in 0..texts.len() {
214 let vec: Vec<f32> = (0..self.dimensions).map(|_| rng.random()).collect();
215 results.push(vec);
216 }
217 return Ok(results);
218 }
219
220 let url = format!(
221 "{}/{}/pipeline/feature-extraction",
222 HUGGINGFACE_API_URL, self.model
223 );
224
225 let mut request = self.client.post(&url).json(&serde_json::json!({
227 "inputs": texts,
228 }));
229
230 if let Some(ref token) = self.api_token {
232 request = request.header("Authorization", format!("Bearer {}", token));
233 }
234
235 let response = request.send().await?;
236
237 if !response.status().is_success() {
238 let error_text = response.text().await?;
239 anyhow::bail!("HuggingFace API error: {}", error_text);
240 }
241
242 let response_json: serde_json::Value = response.json().await?;
243 let mut results = Vec::new();
244
245 if let Some(arr) = response_json.as_array() {
246 for item in arr {
247 let vec: Vec<f32> = serde_json::from_value(item.clone())
248 .map_err(|e| anyhow::anyhow!("Failed to parse embedding: {}", e))?;
249
250 if vec.len() != self.dimensions {
251 anyhow::bail!("Expected {} dimensions, got {}", self.dimensions, vec.len());
252 }
253 results.push(vec);
254 }
255 } else {
256 if texts.len() == 1 {
258 if let Ok(vec) = serde_json::from_value::<Vec<f32>>(response_json) {
259 if vec.len() == self.dimensions {
260 results.push(vec);
261 }
262 }
263 }
264 }
265
266 if results.len() != texts.len() {
267 anyhow::bail!("Expected {} embeddings, got {}", texts.len(), results.len());
268 }
269
270 Ok(results)
271 }
272
273 pub async fn add(&self, key: &str, content: &str, metadata: serde_json::Value) -> Result<usize> {
275 let results = self.add_batch(vec![(key.to_string(), content.to_string(), metadata)]).await?;
276 Ok(results[0])
277 }
278
279 pub async fn add_batch(&self, items: Vec<(String, String, serde_json::Value)>) -> Result<Vec<usize>> {
281 let mut new_items = Vec::new();
283 let mut result_ids = vec![0; items.len()];
284 let mut new_indices = Vec::new(); {
287 let key_map = self.key_to_id.read().unwrap();
288 for (i, (key, content, _)) in items.iter().enumerate() {
289 if let Some(&id) = key_map.get(key) {
290 result_ids[i] = id;
291 } else {
292 new_items.push(content.clone());
293 new_indices.push(i);
294 }
295 }
296 }
297
298 if new_items.is_empty() {
299 return Ok(result_ids);
300 }
301
302 let embeddings = self.embed_batch(new_items).await?;
304
305 let mut ids_to_add = Vec::new();
306 let mut searcher = hnsw::Searcher::default();
307
308 {
310 let mut index = self.index.write().unwrap();
311 let mut key_map = self.key_to_id.write().unwrap();
312 let mut id_map = self.id_to_key.write().unwrap();
313 let mut metadata_map = self.key_to_metadata.write().unwrap();
314 let mut embs = self.embeddings.write().unwrap();
315
316 for (i, embedding) in embeddings.into_iter().enumerate() {
317 let original_idx = new_indices[i];
318 let (key, _, metadata) = &items[original_idx];
319
320 if let Some(&id) = key_map.get(key) {
322 result_ids[original_idx] = id;
323 continue;
324 }
325
326 let id = index.insert(embedding.clone(), &mut searcher);
327 key_map.insert(key.clone(), id);
328 id_map.insert(id, key.clone());
329 metadata_map.insert(key.clone(), metadata.clone());
330
331 embs.push(VectorEntry {
332 key: key.clone(),
333 embedding: embedding,
334 metadata: metadata.clone(),
335 });
336
337 result_ids[original_idx] = id;
338 ids_to_add.push(id);
339 }
340 }
341
342 if !ids_to_add.is_empty() {
343 let _ = self.save_vectors(); }
345
346 Ok(result_ids)
347 }
348
349 pub async fn search(&self, query: &str, k: usize) -> Result<Vec<SearchResult>> {
351 let query_embedding = self.embed(query).await?;
353
354 let mut searcher = hnsw::Searcher::default();
356 let mut neighbors = vec![
357 space::Neighbor {
358 index: 0,
359 distance: u32::MAX
360 };
361 k
362 ];
363
364 let found_neighbors = {
365 let index = self.index.read().unwrap();
366 index.nearest(&query_embedding, 50, &mut searcher, &mut neighbors)
367 };
368
369 let id_map = self.id_to_key.read().unwrap();
371 let metadata_map = self.key_to_metadata.read().unwrap();
372
373 let results: Vec<SearchResult> = found_neighbors
374 .iter()
375 .filter_map(|neighbor| {
376 id_map.get(&neighbor.index).map(|key| {
377 let score_f32 = (neighbor.distance as f32) / 1_000_000.0;
378
379 let metadata = metadata_map.get(key).cloned().unwrap_or(serde_json::Value::Null);
380
381 let uri = metadata.get("uri").and_then(|v| v.as_str()).unwrap_or(key).to_string();
382
383 SearchResult {
384 key: key.clone(),
385 score: 1.0 / (1.0 + score_f32),
386 metadata,
387 uri,
388 }
389 })
390 })
391 .collect();
392
393 Ok(results)
394 }
395
396 pub fn get_key(&self, id: usize) -> Option<String> {
397 self.id_to_key.read().unwrap().get(&id).cloned()
398 }
399
400 pub fn get_id(&self, key: &str) -> Option<usize> {
401 self.key_to_id.read().unwrap().get(key).copied()
402 }
403
404 pub fn len(&self) -> usize {
405 self.key_to_id.read().unwrap().len()
406 }
407
408 pub fn is_empty(&self) -> bool {
409 self.len() == 0
410 }
411
412 pub fn compact(&self) -> Result<usize> {
414 let embeddings = self.embeddings.read().unwrap();
415 let current_keys: std::collections::HashSet<_> = self.key_to_id.read().unwrap().keys().cloned().collect();
416
417 if current_keys.is_empty() && !embeddings.is_empty() {
418 return Ok(0);
419 }
420
421 let active_entries: Vec<_> = embeddings
423 .iter()
424 .filter(|e| current_keys.contains(&e.key))
425 .cloned()
426 .collect();
427
428 let removed = embeddings.len() - active_entries.len();
429
430 if removed == 0 {
431 return Ok(0);
432 }
433
434 let mut new_index = hnsw::Hnsw::new(Euclidian);
436 let mut new_id_to_key = std::collections::HashMap::new();
437 let mut new_key_to_id = std::collections::HashMap::new();
438 let mut new_key_to_metadata = std::collections::HashMap::new();
439 let mut searcher = hnsw::Searcher::default();
440
441 for entry in &active_entries {
442 if entry.embedding.len() == self.dimensions {
443 let id = new_index.insert(entry.embedding.clone(), &mut searcher);
444 new_id_to_key.insert(id, entry.key.clone());
445 new_key_to_id.insert(entry.key.clone(), id);
446 new_key_to_metadata.insert(entry.key.clone(), entry.metadata.clone());
447 }
448 }
449
450 *self.index.write().unwrap() = new_index;
452 *self.id_to_key.write().unwrap() = new_id_to_key;
453 *self.key_to_id.write().unwrap() = new_key_to_id;
454 *self.key_to_metadata.write().unwrap() = new_key_to_metadata;
455
456 drop(embeddings);
458 *self.embeddings.write().unwrap() = active_entries;
459
460 let _ = self.save_vectors();
461
462 Ok(removed)
463 }
464
465 pub fn remove(&self, key: &str) -> bool {
467 let mut key_map = self.key_to_id.write().unwrap();
468 let mut id_map = self.id_to_key.write().unwrap();
469 let mut metadata_map = self.key_to_metadata.write().unwrap();
470
471 if let Some(id) = key_map.remove(key) {
472 id_map.remove(&id);
473 metadata_map.remove(key);
474 true
476 } else {
477 false
478 }
479 }
480
481 pub fn stats(&self) -> (usize, usize, usize) {
483 let embeddings_count = self.embeddings.read().unwrap().len();
484 let active_count = self.key_to_id.read().unwrap().len();
485 let stale_count = embeddings_count.saturating_sub(active_count);
486 (active_count, stale_count, embeddings_count)
487 }
488}