1use crate::persistence::{load_bincode, save_bincode};
2use anyhow::Result;
3use hnsw::Hnsw;
4use rand_pcg::Pcg64;
5use reqwest::Client;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::path::PathBuf;
9use std::sync::atomic::{AtomicUsize, Ordering};
10use std::sync::{Arc, RwLock};
11
/// Base URL of the Hugging Face inference router; the model id and pipeline
/// path are appended to it when an embedding request is built.
const HUGGINGFACE_API_URL: &str = "https://router.huggingface.co/hf-inference/models";

/// Model id used for embedding requests.
const DEFAULT_MODEL: &str = "sentence-transformers/all-MiniLM-L6-v2";

/// Embedding length assumed when `VECTOR_DIMENSIONS` is unset or unparsable.
const DEFAULT_DIMENSIONS: usize = 384;

/// Number of newly added vectors after which the store saves itself.
const DEFAULT_AUTO_SAVE_THRESHOLD: usize = 100;
16
17#[derive(Default, Clone)]
19pub struct Euclidian;
20
21impl space::Metric<Vec<f32>> for Euclidian {
22 type Unit = u32;
23 fn distance(&self, a: &Vec<f32>, b: &Vec<f32>) -> u32 {
24 let len = a.len().min(b.len());
25 let mut dist_sq = 0.0;
26 for i in 0..len {
27 let diff = a[i] - b[i];
28 dist_sq += diff * diff;
29 }
30 (dist_sq.sqrt() * 1_000_000.0) as u32
32 }
33}
34
35#[derive(Serialize, Deserialize, Default)]
37struct VectorData {
38 entries: Vec<VectorEntry>,
39}
40
41#[derive(Serialize, Deserialize, Clone)]
42struct VectorEntry {
43 key: String,
45 embedding: Vec<f32>,
46 #[serde(default)]
48 metadata: serde_json::Value,
49}
50
51pub struct VectorStore {
53 index: Arc<RwLock<Hnsw<Euclidian, Vec<f32>, Pcg64, 16, 32>>>,
55 id_to_key: Arc<RwLock<HashMap<usize, String>>>,
57 key_to_id: Arc<RwLock<HashMap<String, usize>>>,
59 key_to_metadata: Arc<RwLock<HashMap<String, serde_json::Value>>>,
61 storage_path: Option<PathBuf>,
63 client: Client,
65 api_token: Option<String>,
67 model: String,
69 dimensions: usize,
71 embeddings: Arc<RwLock<Vec<VectorEntry>>>,
73 dirty_count: Arc<AtomicUsize>,
75 auto_save_threshold: usize,
77}
78
79#[derive(Debug, Serialize, Deserialize)]
80pub struct SearchResult {
81 pub key: String,
83 pub score: f32,
84 pub metadata: serde_json::Value,
86 pub uri: String,
88}
89
90impl VectorStore {
91 pub fn new(namespace: &str) -> Result<Self> {
93 let storage_path = std::env::var("GRAPH_STORAGE_PATH")
95 .ok()
96 .map(|p| PathBuf::from(p).join(namespace));
97
98 let dimensions = std::env::var("VECTOR_DIMENSIONS")
100 .ok()
101 .and_then(|s| s.parse().ok())
102 .unwrap_or(DEFAULT_DIMENSIONS);
103
104 let mut index = Hnsw::new(Euclidian);
106 let mut id_to_key = HashMap::new();
107 let mut key_to_id = HashMap::new();
108 let mut key_to_metadata = HashMap::new();
109 let mut embeddings = Vec::new();
110
111 if let Some(ref path) = storage_path {
113 let vectors_bin = path.join("vectors.bin");
114 let vectors_json = path.join("vectors.json");
115
116 let loaded_data = if vectors_bin.exists() {
117 load_bincode::<VectorData>(&vectors_bin).ok()
118 } else if vectors_json.exists() {
119 let content = std::fs::read_to_string(&vectors_json).ok();
121 if let Some(content) = content {
122 if let Ok(data) = serde_json::from_str::<VectorData>(&content) {
124 Some(data)
125 } else {
126 #[derive(Serialize, Deserialize)]
128 struct OldVectorData {
129 entries: Vec<OldVectorEntry>,
130 }
131 #[derive(Serialize, Deserialize)]
132 struct OldVectorEntry {
133 uri: String,
134 embedding: Vec<f32>,
135 }
136
137 if let Ok(old_data) = serde_json::from_str::<OldVectorData>(&content) {
138 let entries = old_data
139 .entries
140 .into_iter()
141 .map(|old| VectorEntry {
142 key: old.uri.clone(),
143 embedding: old.embedding,
144 metadata: serde_json::json!({ "uri": old.uri }),
145 })
146 .collect();
147 Some(VectorData { entries })
148 } else {
149 None
150 }
151 }
152 } else {
153 None
154 }
155 } else {
156 None
157 };
158
159 if let Some(data) = loaded_data {
160 let mut searcher = hnsw::Searcher::default();
161 for entry in data.entries {
162 if entry.embedding.len() == dimensions {
163 let id = index.insert(entry.embedding.clone(), &mut searcher);
164 id_to_key.insert(id, entry.key.clone());
165 key_to_id.insert(entry.key.clone(), id);
166 key_to_metadata.insert(entry.key.clone(), entry.metadata.clone());
167 embeddings.push(entry);
168 }
169 }
170 eprintln!(
171 "Loaded {} vectors from disk (dim={})",
172 embeddings.len(),
173 dimensions
174 );
175 }
176 }
177
178 let api_token = std::env::var("HUGGINGFACE_API_TOKEN").ok();
180
181 let client = Client::builder()
183 .timeout(std::time::Duration::from_secs(30))
184 .build()
185 .unwrap_or_else(|_| Client::new());
186
187 Ok(Self {
188 index: Arc::new(RwLock::new(index)),
189 id_to_key: Arc::new(RwLock::new(id_to_key)),
190 key_to_id: Arc::new(RwLock::new(key_to_id)),
191 key_to_metadata: Arc::new(RwLock::new(key_to_metadata)),
192 storage_path,
193 client,
194 api_token,
195 model: DEFAULT_MODEL.to_string(),
196 dimensions,
197 embeddings: Arc::new(RwLock::new(embeddings)),
198 dirty_count: Arc::new(AtomicUsize::new(0)),
199 auto_save_threshold: DEFAULT_AUTO_SAVE_THRESHOLD,
200 })
201 }
202
203 fn save_vectors(&self) -> Result<()> {
205 if let Some(ref path) = self.storage_path {
206 std::fs::create_dir_all(path)?;
207
208 let (entries, current_dirty) = {
230 let guard = self.embeddings.read().unwrap();
231 (guard.clone(), self.dirty_count.load(Ordering::Relaxed))
232 };
233
234 let data = VectorData { entries };
235 save_bincode(&path.join("vectors.bin"), &data)?;
236
237 if current_dirty > 0 {
238 let _ = self.dirty_count.fetch_sub(current_dirty, Ordering::Relaxed);
239 }
240 }
241 Ok(())
242 }
243
244 pub fn flush(&self) -> Result<()> {
246 self.save_vectors()
247 }
248
249 pub async fn embed(&self, text: &str) -> Result<Vec<f32>> {
252 if std::env::var("MOCK_EMBEDDINGS").is_ok() {
253 use rand::Rng;
255 let mut rng = rand::rng();
256 let vec: Vec<f32> = (0..self.dimensions).map(|_| rng.random()).collect();
257 return Ok(vec);
258 }
259
260 let embeddings = self.embed_batch(vec![text.to_string()]).await?;
261 Ok(embeddings[0].clone())
262 }
263
264 pub async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
266 if texts.is_empty() {
267 return Ok(Vec::new());
268 }
269
270 if std::env::var("MOCK_EMBEDDINGS").is_ok() {
271 use rand::Rng;
272 let mut rng = rand::rng();
273 let mut results = Vec::new();
274 for _ in 0..texts.len() {
275 let vec: Vec<f32> = (0..self.dimensions).map(|_| rng.random()).collect();
276 results.push(vec);
277 }
278 return Ok(results);
279 }
280
281 let url = format!(
282 "{}/{}/pipeline/feature-extraction",
283 HUGGINGFACE_API_URL, self.model
284 );
285
286 let mut request = self.client.post(&url).json(&serde_json::json!({
288 "inputs": texts,
289 }));
290
291 if let Some(ref token) = self.api_token {
293 request = request.header("Authorization", format!("Bearer {}", token));
294 }
295
296 let response = request.send().await?;
297
298 if !response.status().is_success() {
299 let error_text = response.text().await?;
300 anyhow::bail!("HuggingFace API error: {}", error_text);
301 }
302
303 let response_json: serde_json::Value = response.json().await?;
304 let mut results = Vec::new();
305
306 if let Some(arr) = response_json.as_array() {
307 for item in arr {
308 let vec: Vec<f32> = serde_json::from_value(item.clone())
309 .map_err(|e| anyhow::anyhow!("Failed to parse embedding: {}", e))?;
310
311 if vec.len() != self.dimensions {
312 anyhow::bail!("Expected {} dimensions, got {}", self.dimensions, vec.len());
313 }
314 results.push(vec);
315 }
316 } else {
317 if texts.len() == 1 {
319 if let Ok(vec) = serde_json::from_value::<Vec<f32>>(response_json) {
320 if vec.len() == self.dimensions {
321 results.push(vec);
322 }
323 }
324 }
325 }
326
327 if results.len() != texts.len() {
328 anyhow::bail!("Expected {} embeddings, got {}", texts.len(), results.len());
329 }
330
331 Ok(results)
332 }
333
334 pub async fn add(
336 &self,
337 key: &str,
338 content: &str,
339 metadata: serde_json::Value,
340 ) -> Result<usize> {
341 let results = self
342 .add_batch(vec![(key.to_string(), content.to_string(), metadata)])
343 .await?;
344 Ok(results[0])
345 }
346
347 pub async fn add_batch(
349 &self,
350 items: Vec<(String, String, serde_json::Value)>,
351 ) -> Result<Vec<usize>> {
352 let mut new_items = Vec::new();
354 let mut result_ids = vec![0; items.len()];
355 let mut new_indices = Vec::new(); {
358 let key_map = self.key_to_id.read().unwrap();
359 for (i, (key, content, _)) in items.iter().enumerate() {
360 if let Some(&id) = key_map.get(key) {
361 result_ids[i] = id;
362 } else {
363 new_items.push(content.clone());
364 new_indices.push(i);
365 }
366 }
367 }
368
369 if new_items.is_empty() {
370 return Ok(result_ids);
371 }
372
373 let embeddings = self.embed_batch(new_items).await?;
375
376 let mut ids_to_add = Vec::new();
377 let mut searcher = hnsw::Searcher::default();
378
379 {
381 let mut index = self.index.write().unwrap();
382 let mut key_map = self.key_to_id.write().unwrap();
383 let mut id_map = self.id_to_key.write().unwrap();
384 let mut metadata_map = self.key_to_metadata.write().unwrap();
385 let mut embs = self.embeddings.write().unwrap();
386
387 for (i, embedding) in embeddings.into_iter().enumerate() {
388 let original_idx = new_indices[i];
389 let (key, _, metadata) = &items[original_idx];
390
391 if let Some(&id) = key_map.get(key) {
393 result_ids[original_idx] = id;
394 continue;
395 }
396
397 let id = index.insert(embedding.clone(), &mut searcher);
398 key_map.insert(key.clone(), id);
399 id_map.insert(id, key.clone());
400 metadata_map.insert(key.clone(), metadata.clone());
401
402 embs.push(VectorEntry {
403 key: key.clone(),
404 embedding,
405 metadata: metadata.clone(),
406 });
407
408 result_ids[original_idx] = id;
409 ids_to_add.push(id);
410 }
411 }
412
413 if !ids_to_add.is_empty() {
414 let count = self
415 .dirty_count
416 .fetch_add(ids_to_add.len(), Ordering::Relaxed);
417 if count + ids_to_add.len() >= self.auto_save_threshold {
418 let _ = self.save_vectors();
419 }
420 }
421
422 Ok(result_ids)
423 }
424
425 pub async fn search(&self, query: &str, k: usize) -> Result<Vec<SearchResult>> {
427 let query_embedding = self.embed(query).await?;
429
430 let mut searcher = hnsw::Searcher::default();
432
433 let index = self.index.read().unwrap();
434 let len = index.len();
435 if len == 0 {
436 return Ok(Vec::new());
437 }
438
439 let k = k.min(len);
440 let ef = k.max(50); let mut neighbors = vec![
443 space::Neighbor {
444 index: 0,
445 distance: u32::MAX
446 };
447 k
448 ];
449
450 let found_neighbors = index.nearest(&query_embedding, ef, &mut searcher, &mut neighbors);
451
452 let id_map = self.id_to_key.read().unwrap();
454 let metadata_map = self.key_to_metadata.read().unwrap();
455
456 let results: Vec<SearchResult> = found_neighbors
457 .iter()
458 .filter_map(|neighbor| {
459 id_map.get(&neighbor.index).map(|key| {
460 let score_f32 = (neighbor.distance as f32) / 1_000_000.0;
461
462 let metadata = metadata_map
463 .get(key)
464 .cloned()
465 .unwrap_or(serde_json::Value::Null);
466
467 let uri = metadata
468 .get("uri")
469 .and_then(|v| v.as_str())
470 .unwrap_or(key)
471 .to_string();
472
473 SearchResult {
474 key: key.clone(),
475 score: 1.0 / (1.0 + score_f32),
476 metadata,
477 uri,
478 }
479 })
480 })
481 .collect();
482
483 Ok(results)
484 }
485
486 pub fn get_key(&self, id: usize) -> Option<String> {
487 self.id_to_key.read().unwrap().get(&id).cloned()
488 }
489
490 pub fn get_id(&self, key: &str) -> Option<usize> {
491 self.key_to_id.read().unwrap().get(key).copied()
492 }
493
494 pub fn len(&self) -> usize {
495 self.key_to_id.read().unwrap().len()
496 }
497
498 pub fn is_empty(&self) -> bool {
499 self.len() == 0
500 }
501
502 pub fn compact(&self) -> Result<usize> {
504 let embeddings = self.embeddings.read().unwrap();
505 let current_keys: std::collections::HashSet<_> =
506 self.key_to_id.read().unwrap().keys().cloned().collect();
507
508 if current_keys.is_empty() && !embeddings.is_empty() {
509 return Ok(0);
510 }
511
512 let active_entries: Vec<_> = embeddings
514 .iter()
515 .filter(|e| current_keys.contains(&e.key))
516 .cloned()
517 .collect();
518
519 let removed = embeddings.len() - active_entries.len();
520
521 if removed == 0 {
522 return Ok(0);
523 }
524
525 let mut new_index = hnsw::Hnsw::new(Euclidian);
527 let mut new_id_to_key = std::collections::HashMap::new();
528 let mut new_key_to_id = std::collections::HashMap::new();
529 let mut new_key_to_metadata = std::collections::HashMap::new();
530 let mut searcher = hnsw::Searcher::default();
531
532 for entry in &active_entries {
533 if entry.embedding.len() == self.dimensions {
534 let id = new_index.insert(entry.embedding.clone(), &mut searcher);
535 new_id_to_key.insert(id, entry.key.clone());
536 new_key_to_id.insert(entry.key.clone(), id);
537 new_key_to_metadata.insert(entry.key.clone(), entry.metadata.clone());
538 }
539 }
540
541 *self.index.write().unwrap() = new_index;
543 *self.id_to_key.write().unwrap() = new_id_to_key;
544 *self.key_to_id.write().unwrap() = new_key_to_id;
545 *self.key_to_metadata.write().unwrap() = new_key_to_metadata;
546
547 drop(embeddings);
549 *self.embeddings.write().unwrap() = active_entries;
550
551 let _ = self.save_vectors();
552
553 Ok(removed)
554 }
555
556 pub fn remove(&self, key: &str) -> bool {
558 let mut key_map = self.key_to_id.write().unwrap();
559 let mut id_map = self.id_to_key.write().unwrap();
560 let mut metadata_map = self.key_to_metadata.write().unwrap();
561
562 if let Some(id) = key_map.remove(key) {
563 id_map.remove(&id);
564 metadata_map.remove(key);
565 true
567 } else {
568 false
569 }
570 }
571
572 pub fn stats(&self) -> (usize, usize, usize) {
574 let embeddings_count = self.embeddings.read().unwrap().len();
575 let active_count = self.key_to_id.read().unwrap().len();
576 let stale_count = embeddings_count.saturating_sub(active_count);
577 (active_count, stale_count, embeddings_count)
578 }
579}