1use crate::persistence::{load_bincode, save_bincode};
2use anyhow::Result;
3use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
4use hnsw::Hnsw;
5use rand_pcg::Pcg64;
6use reqwest::Client;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::path::PathBuf;
10use std::sync::atomic::{AtomicUsize, Ordering};
11use std::sync::{Arc, Mutex, RwLock};
12
// Default base URL for remote embedding requests; `VectorStore::new` uses it
// when the `HUGGINGFACE_API_URL` environment variable is not set.
const HUGGINGFACE_API_URL: &str = "https://router.huggingface.co/hf-inference/models";
// `DEFAULT_MODEL` is the model identifier placed in the request URL.
// `DEFAULT_DIMENSIONS` is the embedding length assumed when `VECTOR_DIMENSIONS`
// is unset or cannot be parsed.
const DEFAULT_MODEL: &str = "sentence-transformers/all-MiniLM-L6-v2"; const DEFAULT_DIMENSIONS: usize = 384;
// Number of newly added vectors after which `add_batch` triggers a save.
const DEFAULT_AUTO_SAVE_THRESHOLD: usize = 100;
17
/// Marker type selecting the Euclidean distance metric for the HNSW index.
///
/// The metric itself is provided by this type's `space::Metric<Vec<f32>>`
/// implementation.
#[derive(Default, Clone)]
pub struct Euclidian;
21
22impl space::Metric<Vec<f32>> for Euclidian {
23 type Unit = u32;
24 fn distance(&self, a: &Vec<f32>, b: &Vec<f32>) -> u32 {
25 let len = a.len().min(b.len());
26 let mut dist_sq = 0.0;
27 for i in 0..len {
28 let diff = a[i] - b[i];
29 dist_sq += diff * diff;
30 }
31 (dist_sq.sqrt() * 1_000_000.0) as u32
33 }
34}
35
/// On-disk container for the store's vectors; this is the value written to
/// `vectors.bin` and read back from `vectors.bin` / `vectors.json`.
#[derive(Serialize, Deserialize, Default)]
struct VectorData {
    /// Every persisted entry, in insertion order.
    entries: Vec<VectorEntry>,
}
41
/// A single stored vector together with its key and metadata.
#[derive(Serialize, Deserialize, Clone)]
struct VectorEntry {
    /// Caller-supplied identifier of the entry.
    key: String,
    /// Embedding vector for the entry's content.
    embedding: Vec<f32>,
    /// Arbitrary JSON attached to the entry; falls back to the type's default
    /// when the field is missing from the serialized data.
    #[serde(default)]
    metadata: serde_json::Value,
}
51
/// Embedding store backed by an HNSW index, with optional on-disk persistence
/// and three embedding back-ends (mock, local model, remote HTTP API).
pub struct VectorStore {
    /// HNSW index over the stored embeddings, using the `Euclidian` metric.
    index: Arc<RwLock<Hnsw<Euclidian, Vec<f32>, Pcg64, 16, 32>>>,
    /// Maps ids returned by the index to entry keys.
    id_to_key: Arc<RwLock<HashMap<usize, String>>>,
    /// Maps entry keys to the ids returned by the index.
    key_to_id: Arc<RwLock<HashMap<String, usize>>>,
    /// Metadata for each entry, by key.
    key_to_metadata: Arc<RwLock<HashMap<String, serde_json::Value>>>,
    /// Directory holding the persisted vectors (`GRAPH_STORAGE_PATH` joined
    /// with the namespace); `None` disables saving.
    storage_path: Option<PathBuf>,
    /// HTTP client used for remote embedding requests.
    client: Client,
    /// Base URL of the remote embedding API.
    api_url: String,
    /// Bearer token sent with remote requests, taken from
    /// `HUGGINGFACE_API_TOKEN`.
    api_token: Option<String>,
    /// Model identifier placed in the remote request URL.
    model: String,
    /// Expected length of every embedding vector.
    dimensions: usize,
    /// All entries that get written to disk; entries whose key was removed
    /// stay here until `compact` runs.
    embeddings: Arc<RwLock<Vec<VectorEntry>>>,
    /// Number of entries added since the last save.
    dirty_count: Arc<AtomicUsize>,
    /// Value of `dirty_count` at which `add_batch` saves automatically.
    auto_save_threshold: usize,
    /// Local embedding model; created only when no API token is configured
    /// and `MOCK_EMBEDDINGS` is unset.
    local_model: Option<Arc<Mutex<TextEmbedding>>>,
}
83
/// One hit returned by `VectorStore::search`.
#[derive(Debug, Serialize, Deserialize)]
pub struct SearchResult {
    /// Key of the matching entry.
    pub key: String,
    /// Similarity derived from the distance as `1 / (1 + distance)`; larger
    /// means closer.
    pub score: f32,
    /// Metadata stored with the entry, or JSON `null` when none is recorded.
    pub metadata: serde_json::Value,
    /// The metadata's `"uri"` string when present, otherwise the key.
    pub uri: String,
}
94
95impl VectorStore {
96 pub fn new(namespace: &str) -> Result<Self> {
98 let storage_path = std::env::var("GRAPH_STORAGE_PATH")
100 .ok()
101 .map(|p| PathBuf::from(p).join(namespace));
102
103 let dimensions = std::env::var("VECTOR_DIMENSIONS")
105 .ok()
106 .and_then(|s| s.parse().ok())
107 .unwrap_or(DEFAULT_DIMENSIONS);
108
109 let mut index = Hnsw::new(Euclidian);
111 let mut id_to_key = HashMap::new();
112 let mut key_to_id = HashMap::new();
113 let mut key_to_metadata = HashMap::new();
114 let mut embeddings = Vec::new();
115
116 if let Some(ref path) = storage_path {
118 let vectors_bin = path.join("vectors.bin");
119 let vectors_json = path.join("vectors.json");
120
121 let loaded_data = if vectors_bin.exists() {
122 load_bincode::<VectorData>(&vectors_bin).ok()
123 } else if vectors_json.exists() {
124 let content = std::fs::read_to_string(&vectors_json).ok();
126 if let Some(content) = content {
127 if let Ok(data) = serde_json::from_str::<VectorData>(&content) {
129 Some(data)
130 } else {
131 #[derive(Serialize, Deserialize)]
133 struct OldVectorData {
134 entries: Vec<OldVectorEntry>,
135 }
136 #[derive(Serialize, Deserialize)]
137 struct OldVectorEntry {
138 uri: String,
139 embedding: Vec<f32>,
140 }
141
142 if let Ok(old_data) = serde_json::from_str::<OldVectorData>(&content) {
143 let entries = old_data
144 .entries
145 .into_iter()
146 .map(|old| VectorEntry {
147 key: old.uri.clone(),
148 embedding: old.embedding,
149 metadata: serde_json::json!({ "uri": old.uri }),
150 })
151 .collect();
152 Some(VectorData { entries })
153 } else {
154 None
155 }
156 }
157 } else {
158 None
159 }
160 } else {
161 None
162 };
163
164 if let Some(data) = loaded_data {
165 let mut searcher = hnsw::Searcher::default();
166 for entry in data.entries {
167 if entry.embedding.len() == dimensions {
168 let id = index.insert(entry.embedding.clone(), &mut searcher);
169 id_to_key.insert(id, entry.key.clone());
170 key_to_id.insert(entry.key.clone(), id);
171 key_to_metadata.insert(entry.key.clone(), entry.metadata.clone());
172 embeddings.push(entry);
173 }
174 }
175 eprintln!(
176 "Loaded {} vectors from disk (dim={})",
177 embeddings.len(),
178 dimensions
179 );
180 }
181 }
182
183 let api_token = std::env::var("HUGGINGFACE_API_TOKEN").ok();
185
186 let api_url = std::env::var("HUGGINGFACE_API_URL")
188 .unwrap_or_else(|_| HUGGINGFACE_API_URL.to_string());
189
190 let client = Client::builder()
192 .timeout(std::time::Duration::from_secs(30))
193 .build()
194 .unwrap_or_else(|_| Client::new());
195
196 let local_model = if api_token.is_none() && std::env::var("MOCK_EMBEDDINGS").is_err() {
197 eprintln!("Initializing local embedding model (this may take a moment to download)...");
198 let model = TextEmbedding::try_new(
199 InitOptions::new(EmbeddingModel::AllMiniLML6V2).with_show_download_progress(true),
200 )?;
201 Some(Arc::new(Mutex::new(model)))
202 } else {
203 None
204 };
205
206 Ok(Self {
207 index: Arc::new(RwLock::new(index)),
208 id_to_key: Arc::new(RwLock::new(id_to_key)),
209 key_to_id: Arc::new(RwLock::new(key_to_id)),
210 key_to_metadata: Arc::new(RwLock::new(key_to_metadata)),
211 storage_path,
212 client,
213 api_url,
214 api_token,
215 model: DEFAULT_MODEL.to_string(),
216 dimensions,
217 embeddings: Arc::new(RwLock::new(embeddings)),
218 dirty_count: Arc::new(AtomicUsize::new(0)),
219 auto_save_threshold: DEFAULT_AUTO_SAVE_THRESHOLD,
220 local_model,
221 })
222 }
223
224 fn save_vectors(&self) -> Result<()> {
226 if let Some(ref path) = self.storage_path {
227 std::fs::create_dir_all(path)?;
228
229 let (entries, current_dirty) = {
251 let guard = self.embeddings.read().unwrap();
252 (guard.clone(), self.dirty_count.load(Ordering::Relaxed))
253 };
254
255 let data = VectorData { entries };
256 save_bincode(&path.join("vectors.bin"), &data)?;
257
258 if current_dirty > 0 {
259 let _ = self.dirty_count.fetch_sub(current_dirty, Ordering::Relaxed);
260 }
261 }
262 Ok(())
263 }
264
265 pub fn flush(&self) -> Result<()> {
267 self.save_vectors()
268 }
269
270 pub async fn embed(&self, text: &str) -> Result<Vec<f32>> {
273 if std::env::var("MOCK_EMBEDDINGS").is_ok() {
274 use rand::Rng;
276 let mut rng = rand::rng();
277 let vec: Vec<f32> = (0..self.dimensions).map(|_| rng.random()).collect();
278 return Ok(vec);
279 }
280
281 let embeddings = self.embed_batch(vec![text.to_string()]).await?;
282 Ok(embeddings[0].clone())
283 }
284
285 pub async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
287 if texts.is_empty() {
288 return Ok(Vec::new());
289 }
290
291 if std::env::var("MOCK_EMBEDDINGS").is_ok() {
292 use rand::Rng;
293 let mut rng = rand::rng();
294 let mut results = Vec::new();
295 for _ in 0..texts.len() {
296 let vec: Vec<f32> = (0..self.dimensions).map(|_| rng.random()).collect();
297 results.push(vec);
298 }
299 return Ok(results);
300 }
301
302 if let Some(ref model) = self.local_model {
303 let model = model.clone();
304 let texts = texts.clone();
305 let embeddings = tokio::task::spawn_blocking(move || {
306 let mut model = model.lock().unwrap();
307 model.embed(texts, None)
308 })
309 .await??;
310 return Ok(embeddings);
311 }
312
313 let url = format!(
314 "{}/{}/pipeline/feature-extraction",
315 self.api_url, self.model
316 );
317
318 let mut request = self.client.post(&url).json(&serde_json::json!({
320 "inputs": texts,
321 }));
322
323 if let Some(ref token) = self.api_token {
325 request = request.header("Authorization", format!("Bearer {}", token));
326 }
327
328 let response = request.send().await?;
329
330 if !response.status().is_success() {
331 let error_text = response.text().await?;
332 anyhow::bail!("HuggingFace API error: {}", error_text);
333 }
334
335 let response_json: serde_json::Value = response.json().await?;
336 let mut results = Vec::new();
337
338 if let Some(arr) = response_json.as_array() {
339 for item in arr {
340 let vec: Vec<f32> = serde_json::from_value(item.clone())
341 .map_err(|e| anyhow::anyhow!("Failed to parse embedding: {}", e))?;
342
343 if vec.len() != self.dimensions {
344 anyhow::bail!("Expected {} dimensions, got {}", self.dimensions, vec.len());
345 }
346 results.push(vec);
347 }
348 } else {
349 if texts.len() == 1 {
351 if let Ok(vec) = serde_json::from_value::<Vec<f32>>(response_json) {
352 if vec.len() == self.dimensions {
353 results.push(vec);
354 }
355 }
356 }
357 }
358
359 if results.len() != texts.len() {
360 anyhow::bail!("Expected {} embeddings, got {}", texts.len(), results.len());
361 }
362
363 Ok(results)
364 }
365
366 pub async fn add(
368 &self,
369 key: &str,
370 content: &str,
371 metadata: serde_json::Value,
372 ) -> Result<usize> {
373 let results = self
374 .add_batch(vec![(key.to_string(), content.to_string(), metadata)])
375 .await?;
376 Ok(results[0])
377 }
378
379 pub async fn add_batch(
381 &self,
382 items: Vec<(String, String, serde_json::Value)>,
383 ) -> Result<Vec<usize>> {
384 let mut new_items = Vec::new();
386 let mut result_ids = vec![0; items.len()];
387 let mut new_indices = Vec::new(); {
390 let key_map = self.key_to_id.read().unwrap();
391 for (i, (key, content, _)) in items.iter().enumerate() {
392 if let Some(&id) = key_map.get(key) {
393 result_ids[i] = id;
394 } else {
395 new_items.push(content.clone());
396 new_indices.push(i);
397 }
398 }
399 }
400
401 if new_items.is_empty() {
402 return Ok(result_ids);
403 }
404
405 let embeddings = self.embed_batch(new_items).await?;
407
408 let mut ids_to_add = Vec::new();
409 let mut searcher = hnsw::Searcher::default();
410
411 {
413 let mut index = self.index.write().unwrap();
414 let mut key_map = self.key_to_id.write().unwrap();
415 let mut id_map = self.id_to_key.write().unwrap();
416 let mut metadata_map = self.key_to_metadata.write().unwrap();
417 let mut embs = self.embeddings.write().unwrap();
418
419 for (i, embedding) in embeddings.into_iter().enumerate() {
420 let original_idx = new_indices[i];
421 let (key, _, metadata) = &items[original_idx];
422
423 if let Some(&id) = key_map.get(key) {
425 result_ids[original_idx] = id;
426 continue;
427 }
428
429 let id = index.insert(embedding.clone(), &mut searcher);
430 key_map.insert(key.clone(), id);
431 id_map.insert(id, key.clone());
432 metadata_map.insert(key.clone(), metadata.clone());
433
434 embs.push(VectorEntry {
435 key: key.clone(),
436 embedding,
437 metadata: metadata.clone(),
438 });
439
440 result_ids[original_idx] = id;
441 ids_to_add.push(id);
442 }
443 }
444
445 if !ids_to_add.is_empty() {
446 let count = self
447 .dirty_count
448 .fetch_add(ids_to_add.len(), Ordering::Relaxed);
449 if count + ids_to_add.len() >= self.auto_save_threshold {
450 let _ = self.save_vectors();
451 }
452 }
453
454 Ok(result_ids)
455 }
456
457 pub async fn search(&self, query: &str, k: usize) -> Result<Vec<SearchResult>> {
459 let query_embedding = self.embed(query).await?;
461
462 let mut searcher = hnsw::Searcher::default();
464
465 let index = self.index.read().unwrap();
466 let len = index.len();
467 if len == 0 {
468 return Ok(Vec::new());
469 }
470
471 let k = k.min(len);
472 let ef = k.max(50); let mut neighbors = vec![
475 space::Neighbor {
476 index: 0,
477 distance: u32::MAX
478 };
479 k
480 ];
481
482 let found_neighbors = index.nearest(&query_embedding, ef, &mut searcher, &mut neighbors);
483
484 let id_map = self.id_to_key.read().unwrap();
486 let metadata_map = self.key_to_metadata.read().unwrap();
487
488 let results: Vec<SearchResult> = found_neighbors
489 .iter()
490 .filter_map(|neighbor| {
491 id_map.get(&neighbor.index).map(|key| {
492 let score_f32 = (neighbor.distance as f32) / 1_000_000.0;
493
494 let metadata = metadata_map
495 .get(key)
496 .cloned()
497 .unwrap_or(serde_json::Value::Null);
498
499 let uri = metadata
500 .get("uri")
501 .and_then(|v| v.as_str())
502 .unwrap_or(key)
503 .to_string();
504
505 SearchResult {
506 key: key.clone(),
507 score: 1.0 / (1.0 + score_f32),
508 metadata,
509 uri,
510 }
511 })
512 })
513 .collect();
514
515 Ok(results)
516 }
517
518 pub fn get_key(&self, id: usize) -> Option<String> {
519 self.id_to_key.read().unwrap().get(&id).cloned()
520 }
521
522 pub fn get_id(&self, key: &str) -> Option<usize> {
523 self.key_to_id.read().unwrap().get(key).copied()
524 }
525
526 pub fn len(&self) -> usize {
527 self.key_to_id.read().unwrap().len()
528 }
529
530 pub fn is_empty(&self) -> bool {
531 self.len() == 0
532 }
533
534 pub fn compact(&self) -> Result<usize> {
536 let embeddings = self.embeddings.read().unwrap();
537 let current_keys: std::collections::HashSet<_> =
538 self.key_to_id.read().unwrap().keys().cloned().collect();
539
540 if current_keys.is_empty() && !embeddings.is_empty() {
541 return Ok(0);
542 }
543
544 let active_entries: Vec<_> = embeddings
546 .iter()
547 .filter(|e| current_keys.contains(&e.key))
548 .cloned()
549 .collect();
550
551 let removed = embeddings.len() - active_entries.len();
552
553 if removed == 0 {
554 return Ok(0);
555 }
556
557 let mut new_index = hnsw::Hnsw::new(Euclidian);
559 let mut new_id_to_key = std::collections::HashMap::new();
560 let mut new_key_to_id = std::collections::HashMap::new();
561 let mut new_key_to_metadata = std::collections::HashMap::new();
562 let mut searcher = hnsw::Searcher::default();
563
564 for entry in &active_entries {
565 if entry.embedding.len() == self.dimensions {
566 let id = new_index.insert(entry.embedding.clone(), &mut searcher);
567 new_id_to_key.insert(id, entry.key.clone());
568 new_key_to_id.insert(entry.key.clone(), id);
569 new_key_to_metadata.insert(entry.key.clone(), entry.metadata.clone());
570 }
571 }
572
573 *self.index.write().unwrap() = new_index;
575 *self.id_to_key.write().unwrap() = new_id_to_key;
576 *self.key_to_id.write().unwrap() = new_key_to_id;
577 *self.key_to_metadata.write().unwrap() = new_key_to_metadata;
578
579 drop(embeddings);
581 *self.embeddings.write().unwrap() = active_entries;
582
583 let _ = self.save_vectors();
584
585 Ok(removed)
586 }
587
588 pub fn remove(&self, key: &str) -> bool {
590 let mut key_map = self.key_to_id.write().unwrap();
591 let mut id_map = self.id_to_key.write().unwrap();
592 let mut metadata_map = self.key_to_metadata.write().unwrap();
593
594 if let Some(id) = key_map.remove(key) {
595 id_map.remove(&id);
596 metadata_map.remove(key);
597 true
599 } else {
600 false
601 }
602 }
603
604 pub fn stats(&self) -> (usize, usize, usize) {
606 let embeddings_count = self.embeddings.read().unwrap().len();
607 let active_count = self.key_to_id.read().unwrap().len();
608 let stale_count = embeddings_count.saturating_sub(active_count);
609 (active_count, stale_count, embeddings_count)
610 }
611}