1use crate::persistence::{load_bincode, save_bincode};
2use anyhow::Result;
3use hnsw::Hnsw;
4use rand_pcg::Pcg64;
5use reqwest::Client;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::path::PathBuf;
9use std::sync::atomic::{AtomicUsize, Ordering};
10use std::sync::{Arc, RwLock};
11
/// Base URL of the embedding endpoint, used when the `HUGGINGFACE_API_URL`
/// environment variable is not set.
const HUGGINGFACE_API_URL: &str = "https://router.huggingface.co/hf-inference/models";
/// Model identifier inserted into the request URL when embedding text.
const DEFAULT_MODEL: &str = "sentence-transformers/all-MiniLM-L6-v2";
/// Embedding length assumed when `VECTOR_DIMENSIONS` is unset or cannot be parsed.
const DEFAULT_DIMENSIONS: usize = 384;
/// Number of unsaved insertions at which `add_batch` triggers a save to disk.
const DEFAULT_AUTO_SAVE_THRESHOLD: usize = 100;
16
17#[derive(Default, Clone)]
19pub struct Euclidian;
20
21impl space::Metric<Vec<f32>> for Euclidian {
22 type Unit = u32;
23 fn distance(&self, a: &Vec<f32>, b: &Vec<f32>) -> u32 {
24 let len = a.len().min(b.len());
25 let mut dist_sq = 0.0;
26 for i in 0..len {
27 let diff = a[i] - b[i];
28 dist_sq += diff * diff;
29 }
30 (dist_sq.sqrt() * 1_000_000.0) as u32
32 }
33}
34
/// Serialized container for all stored vectors; this is the payload written
/// to and read from `vectors.bin` / `vectors.json`.
#[derive(Serialize, Deserialize, Default)]
struct VectorData {
    /// Every persisted entry, in the order it was stored.
    entries: Vec<VectorEntry>,
}
40
/// One stored vector: its key, its embedding, and caller-supplied metadata.
#[derive(Serialize, Deserialize, Clone)]
struct VectorEntry {
    /// Caller-chosen identifier of the entry.
    key: String,
    /// Embedding vector for the entry's content.
    embedding: Vec<f32>,
    /// Arbitrary JSON metadata; defaults to JSON `null` when absent from the
    /// serialized form.
    #[serde(default)]
    metadata: serde_json::Value,
}
50
/// Vector store that embeds text through an HTTP API, indexes the embeddings
/// in an HNSW graph, and optionally persists them to disk.
pub struct VectorStore {
    /// HNSW index over the stored embeddings, using the `Euclidian` metric.
    index: Arc<RwLock<Hnsw<Euclidian, Vec<f32>, Pcg64, 16, 32>>>,
    /// Maps an index id (as returned by `Hnsw::insert`) to its key.
    id_to_key: Arc<RwLock<HashMap<usize, String>>>,
    /// Maps a key to its index id; `len` and `stats` treat these as the active entries.
    key_to_id: Arc<RwLock<HashMap<String, usize>>>,
    /// Maps a key to the metadata supplied when it was added.
    key_to_metadata: Arc<RwLock<HashMap<String, serde_json::Value>>>,
    /// Directory holding the persisted vectors (`GRAPH_STORAGE_PATH/<namespace>`);
    /// `None` disables persistence.
    storage_path: Option<PathBuf>,
    /// HTTP client used for embedding requests.
    client: Client,
    /// Base URL of the embedding API.
    api_url: String,
    /// Optional bearer token sent in the `Authorization` header.
    api_token: Option<String>,
    /// Model identifier appended to `api_url` when building the request URL.
    model: String,
    /// Expected length of every embedding vector.
    dimensions: usize,
    /// All entries in insertion order; this is what `save_vectors` persists.
    /// `remove` does not delete from this list, so it may hold entries whose
    /// key is no longer active until `compact` runs.
    embeddings: Arc<RwLock<Vec<VectorEntry>>>,
    /// Number of entries inserted since the last successful save.
    dirty_count: Arc<AtomicUsize>,
    /// Value of `dirty_count` at which `add_batch` triggers a save.
    auto_save_threshold: usize,
}
80
/// A single hit returned by `VectorStore::search`.
#[derive(Debug, Serialize, Deserialize)]
pub struct SearchResult {
    /// Key of the matched entry.
    pub key: String,
    /// Similarity score computed as `1 / (1 + distance)`; larger means closer.
    pub score: f32,
    /// Metadata stored with the entry, or JSON `null` when none is recorded.
    pub metadata: serde_json::Value,
    /// The string under `"uri"` in the metadata, falling back to the key.
    pub uri: String,
}
91
92impl VectorStore {
93 pub fn new(namespace: &str) -> Result<Self> {
95 let storage_path = std::env::var("GRAPH_STORAGE_PATH")
97 .ok()
98 .map(|p| PathBuf::from(p).join(namespace));
99
100 let dimensions = std::env::var("VECTOR_DIMENSIONS")
102 .ok()
103 .and_then(|s| s.parse().ok())
104 .unwrap_or(DEFAULT_DIMENSIONS);
105
106 let mut index = Hnsw::new(Euclidian);
108 let mut id_to_key = HashMap::new();
109 let mut key_to_id = HashMap::new();
110 let mut key_to_metadata = HashMap::new();
111 let mut embeddings = Vec::new();
112
113 if let Some(ref path) = storage_path {
115 let vectors_bin = path.join("vectors.bin");
116 let vectors_json = path.join("vectors.json");
117
118 let loaded_data = if vectors_bin.exists() {
119 load_bincode::<VectorData>(&vectors_bin).ok()
120 } else if vectors_json.exists() {
121 let content = std::fs::read_to_string(&vectors_json).ok();
123 if let Some(content) = content {
124 if let Ok(data) = serde_json::from_str::<VectorData>(&content) {
126 Some(data)
127 } else {
128 #[derive(Serialize, Deserialize)]
130 struct OldVectorData {
131 entries: Vec<OldVectorEntry>,
132 }
133 #[derive(Serialize, Deserialize)]
134 struct OldVectorEntry {
135 uri: String,
136 embedding: Vec<f32>,
137 }
138
139 if let Ok(old_data) = serde_json::from_str::<OldVectorData>(&content) {
140 let entries = old_data
141 .entries
142 .into_iter()
143 .map(|old| VectorEntry {
144 key: old.uri.clone(),
145 embedding: old.embedding,
146 metadata: serde_json::json!({ "uri": old.uri }),
147 })
148 .collect();
149 Some(VectorData { entries })
150 } else {
151 None
152 }
153 }
154 } else {
155 None
156 }
157 } else {
158 None
159 };
160
161 if let Some(data) = loaded_data {
162 let mut searcher = hnsw::Searcher::default();
163 for entry in data.entries {
164 if entry.embedding.len() == dimensions {
165 let id = index.insert(entry.embedding.clone(), &mut searcher);
166 id_to_key.insert(id, entry.key.clone());
167 key_to_id.insert(entry.key.clone(), id);
168 key_to_metadata.insert(entry.key.clone(), entry.metadata.clone());
169 embeddings.push(entry);
170 }
171 }
172 eprintln!(
173 "Loaded {} vectors from disk (dim={})",
174 embeddings.len(),
175 dimensions
176 );
177 }
178 }
179
180 let api_token = std::env::var("HUGGINGFACE_API_TOKEN").ok();
182
183 let api_url = std::env::var("HUGGINGFACE_API_URL")
185 .unwrap_or_else(|_| HUGGINGFACE_API_URL.to_string());
186
187 let client = Client::builder()
189 .timeout(std::time::Duration::from_secs(30))
190 .build()
191 .unwrap_or_else(|_| Client::new());
192
193 Ok(Self {
194 index: Arc::new(RwLock::new(index)),
195 id_to_key: Arc::new(RwLock::new(id_to_key)),
196 key_to_id: Arc::new(RwLock::new(key_to_id)),
197 key_to_metadata: Arc::new(RwLock::new(key_to_metadata)),
198 storage_path,
199 client,
200 api_url,
201 api_token,
202 model: DEFAULT_MODEL.to_string(),
203 dimensions,
204 embeddings: Arc::new(RwLock::new(embeddings)),
205 dirty_count: Arc::new(AtomicUsize::new(0)),
206 auto_save_threshold: DEFAULT_AUTO_SAVE_THRESHOLD,
207 })
208 }
209
210 fn save_vectors(&self) -> Result<()> {
212 if let Some(ref path) = self.storage_path {
213 std::fs::create_dir_all(path)?;
214
215 let (entries, current_dirty) = {
237 let guard = self.embeddings.read().unwrap();
238 (guard.clone(), self.dirty_count.load(Ordering::Relaxed))
239 };
240
241 let data = VectorData { entries };
242 save_bincode(&path.join("vectors.bin"), &data)?;
243
244 if current_dirty > 0 {
245 let _ = self.dirty_count.fetch_sub(current_dirty, Ordering::Relaxed);
246 }
247 }
248 Ok(())
249 }
250
    /// Persists the current vectors to disk immediately by delegating to
    /// `save_vectors`.
    ///
    /// # Errors
    /// Propagates any error returned by `save_vectors`.
    pub fn flush(&self) -> Result<()> {
        self.save_vectors()
    }
255
256 pub async fn embed(&self, text: &str) -> Result<Vec<f32>> {
259 if std::env::var("MOCK_EMBEDDINGS").is_ok() {
260 use rand::Rng;
262 let mut rng = rand::rng();
263 let vec: Vec<f32> = (0..self.dimensions).map(|_| rng.random()).collect();
264 return Ok(vec);
265 }
266
267 let embeddings = self.embed_batch(vec![text.to_string()]).await?;
268 Ok(embeddings[0].clone())
269 }
270
271 pub async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
273 if texts.is_empty() {
274 return Ok(Vec::new());
275 }
276
277 if std::env::var("MOCK_EMBEDDINGS").is_ok() {
278 use rand::Rng;
279 let mut rng = rand::rng();
280 let mut results = Vec::new();
281 for _ in 0..texts.len() {
282 let vec: Vec<f32> = (0..self.dimensions).map(|_| rng.random()).collect();
283 results.push(vec);
284 }
285 return Ok(results);
286 }
287
288 let url = format!(
289 "{}/{}/pipeline/feature-extraction",
290 self.api_url, self.model
291 );
292
293 let mut request = self.client.post(&url).json(&serde_json::json!({
295 "inputs": texts,
296 }));
297
298 if let Some(ref token) = self.api_token {
300 request = request.header("Authorization", format!("Bearer {}", token));
301 }
302
303 let response = request.send().await?;
304
305 if !response.status().is_success() {
306 let error_text = response.text().await?;
307 anyhow::bail!("HuggingFace API error: {}", error_text);
308 }
309
310 let response_json: serde_json::Value = response.json().await?;
311 let mut results = Vec::new();
312
313 if let Some(arr) = response_json.as_array() {
314 for item in arr {
315 let vec: Vec<f32> = serde_json::from_value(item.clone())
316 .map_err(|e| anyhow::anyhow!("Failed to parse embedding: {}", e))?;
317
318 if vec.len() != self.dimensions {
319 anyhow::bail!("Expected {} dimensions, got {}", self.dimensions, vec.len());
320 }
321 results.push(vec);
322 }
323 } else {
324 if texts.len() == 1 {
326 if let Ok(vec) = serde_json::from_value::<Vec<f32>>(response_json) {
327 if vec.len() == self.dimensions {
328 results.push(vec);
329 }
330 }
331 }
332 }
333
334 if results.len() != texts.len() {
335 anyhow::bail!("Expected {} embeddings, got {}", texts.len(), results.len());
336 }
337
338 Ok(results)
339 }
340
341 pub async fn add(
343 &self,
344 key: &str,
345 content: &str,
346 metadata: serde_json::Value,
347 ) -> Result<usize> {
348 let results = self
349 .add_batch(vec![(key.to_string(), content.to_string(), metadata)])
350 .await?;
351 Ok(results[0])
352 }
353
354 pub async fn add_batch(
356 &self,
357 items: Vec<(String, String, serde_json::Value)>,
358 ) -> Result<Vec<usize>> {
359 let mut new_items = Vec::new();
361 let mut result_ids = vec![0; items.len()];
362 let mut new_indices = Vec::new(); {
365 let key_map = self.key_to_id.read().unwrap();
366 for (i, (key, content, _)) in items.iter().enumerate() {
367 if let Some(&id) = key_map.get(key) {
368 result_ids[i] = id;
369 } else {
370 new_items.push(content.clone());
371 new_indices.push(i);
372 }
373 }
374 }
375
376 if new_items.is_empty() {
377 return Ok(result_ids);
378 }
379
380 let embeddings = self.embed_batch(new_items).await?;
382
383 let mut ids_to_add = Vec::new();
384 let mut searcher = hnsw::Searcher::default();
385
386 {
388 let mut index = self.index.write().unwrap();
389 let mut key_map = self.key_to_id.write().unwrap();
390 let mut id_map = self.id_to_key.write().unwrap();
391 let mut metadata_map = self.key_to_metadata.write().unwrap();
392 let mut embs = self.embeddings.write().unwrap();
393
394 for (i, embedding) in embeddings.into_iter().enumerate() {
395 let original_idx = new_indices[i];
396 let (key, _, metadata) = &items[original_idx];
397
398 if let Some(&id) = key_map.get(key) {
400 result_ids[original_idx] = id;
401 continue;
402 }
403
404 let id = index.insert(embedding.clone(), &mut searcher);
405 key_map.insert(key.clone(), id);
406 id_map.insert(id, key.clone());
407 metadata_map.insert(key.clone(), metadata.clone());
408
409 embs.push(VectorEntry {
410 key: key.clone(),
411 embedding,
412 metadata: metadata.clone(),
413 });
414
415 result_ids[original_idx] = id;
416 ids_to_add.push(id);
417 }
418 }
419
420 if !ids_to_add.is_empty() {
421 let count = self
422 .dirty_count
423 .fetch_add(ids_to_add.len(), Ordering::Relaxed);
424 if count + ids_to_add.len() >= self.auto_save_threshold {
425 let _ = self.save_vectors();
426 }
427 }
428
429 Ok(result_ids)
430 }
431
432 pub async fn search(&self, query: &str, k: usize) -> Result<Vec<SearchResult>> {
434 let query_embedding = self.embed(query).await?;
436
437 let mut searcher = hnsw::Searcher::default();
439
440 let index = self.index.read().unwrap();
441 let len = index.len();
442 if len == 0 {
443 return Ok(Vec::new());
444 }
445
446 let k = k.min(len);
447 let ef = k.max(50); let mut neighbors = vec![
450 space::Neighbor {
451 index: 0,
452 distance: u32::MAX
453 };
454 k
455 ];
456
457 let found_neighbors = index.nearest(&query_embedding, ef, &mut searcher, &mut neighbors);
458
459 let id_map = self.id_to_key.read().unwrap();
461 let metadata_map = self.key_to_metadata.read().unwrap();
462
463 let results: Vec<SearchResult> = found_neighbors
464 .iter()
465 .filter_map(|neighbor| {
466 id_map.get(&neighbor.index).map(|key| {
467 let score_f32 = (neighbor.distance as f32) / 1_000_000.0;
468
469 let metadata = metadata_map
470 .get(key)
471 .cloned()
472 .unwrap_or(serde_json::Value::Null);
473
474 let uri = metadata
475 .get("uri")
476 .and_then(|v| v.as_str())
477 .unwrap_or(key)
478 .to_string();
479
480 SearchResult {
481 key: key.clone(),
482 score: 1.0 / (1.0 + score_f32),
483 metadata,
484 uri,
485 }
486 })
487 })
488 .collect();
489
490 Ok(results)
491 }
492
493 pub fn get_key(&self, id: usize) -> Option<String> {
494 self.id_to_key.read().unwrap().get(&id).cloned()
495 }
496
497 pub fn get_id(&self, key: &str) -> Option<usize> {
498 self.key_to_id.read().unwrap().get(key).copied()
499 }
500
501 pub fn len(&self) -> usize {
502 self.key_to_id.read().unwrap().len()
503 }
504
    /// Returns `true` when the store holds no active keys (i.e. `len()` is zero).
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
508
509 pub fn compact(&self) -> Result<usize> {
511 let embeddings = self.embeddings.read().unwrap();
512 let current_keys: std::collections::HashSet<_> =
513 self.key_to_id.read().unwrap().keys().cloned().collect();
514
515 if current_keys.is_empty() && !embeddings.is_empty() {
516 return Ok(0);
517 }
518
519 let active_entries: Vec<_> = embeddings
521 .iter()
522 .filter(|e| current_keys.contains(&e.key))
523 .cloned()
524 .collect();
525
526 let removed = embeddings.len() - active_entries.len();
527
528 if removed == 0 {
529 return Ok(0);
530 }
531
532 let mut new_index = hnsw::Hnsw::new(Euclidian);
534 let mut new_id_to_key = std::collections::HashMap::new();
535 let mut new_key_to_id = std::collections::HashMap::new();
536 let mut new_key_to_metadata = std::collections::HashMap::new();
537 let mut searcher = hnsw::Searcher::default();
538
539 for entry in &active_entries {
540 if entry.embedding.len() == self.dimensions {
541 let id = new_index.insert(entry.embedding.clone(), &mut searcher);
542 new_id_to_key.insert(id, entry.key.clone());
543 new_key_to_id.insert(entry.key.clone(), id);
544 new_key_to_metadata.insert(entry.key.clone(), entry.metadata.clone());
545 }
546 }
547
548 *self.index.write().unwrap() = new_index;
550 *self.id_to_key.write().unwrap() = new_id_to_key;
551 *self.key_to_id.write().unwrap() = new_key_to_id;
552 *self.key_to_metadata.write().unwrap() = new_key_to_metadata;
553
554 drop(embeddings);
556 *self.embeddings.write().unwrap() = active_entries;
557
558 let _ = self.save_vectors();
559
560 Ok(removed)
561 }
562
563 pub fn remove(&self, key: &str) -> bool {
565 let mut key_map = self.key_to_id.write().unwrap();
566 let mut id_map = self.id_to_key.write().unwrap();
567 let mut metadata_map = self.key_to_metadata.write().unwrap();
568
569 if let Some(id) = key_map.remove(key) {
570 id_map.remove(&id);
571 metadata_map.remove(key);
572 true
574 } else {
575 false
576 }
577 }
578
579 pub fn stats(&self) -> (usize, usize, usize) {
581 let embeddings_count = self.embeddings.read().unwrap().len();
582 let active_count = self.key_to_id.read().unwrap().len();
583 let stale_count = embeddings_count.saturating_sub(active_count);
584 (active_count, stale_count, embeddings_count)
585 }
586}