1use crate::persistence::{load_bincode, save_bincode};
2use anyhow::Result;
3use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
4use hnsw::Hnsw;
5use rand_pcg::Pcg64;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::path::PathBuf;
9use std::sync::atomic::{AtomicUsize, Ordering};
10use std::sync::{Arc, RwLock};
11
12const DEFAULT_DIMENSIONS: usize = 384;
13const DEFAULT_AUTO_SAVE_THRESHOLD: usize = 100;
14
/// Zero-sized marker type used as the distance metric of the HNSW index.
///
/// Its `space::Metric` implementation computes the straight-line (L2)
/// distance between two `Vec<f32>` points.
#[derive(Default, Clone)]
pub struct Euclidian;
18
19impl space::Metric<Vec<f32>> for Euclidian {
20 type Unit = u32;
21 fn distance(&self, a: &Vec<f32>, b: &Vec<f32>) -> u32 {
22 let len = a.len().min(b.len());
23 let mut dist_sq = 0.0;
24 for i in 0..len {
25 let diff = a[i] - b[i];
26 dist_sq += diff * diff;
27 }
28 (dist_sq.sqrt() * 1_000_000.0) as u32
30 }
31}
32
/// Serialized form of the store: the list of entries written to
/// `vectors.bin` and read back when a store is constructed.
#[derive(Serialize, Deserialize, Default)]
struct VectorData {
    /// All persisted entries, in the order they were stored.
    entries: Vec<VectorEntry>,
}
38
/// One stored vector together with the key and metadata it was added under.
#[derive(Serialize, Deserialize, Clone)]
struct VectorEntry {
    /// Caller-supplied identifier of the entry.
    key: String,
    /// The embedding vector stored for `key`.
    embedding: Vec<f32>,
    /// Arbitrary JSON attached to the entry; falls back to the type's default
    /// value when the field is absent from the serialized data.
    #[serde(default)]
    metadata: serde_json::Value,
}
48
/// Embedding store that pairs an HNSW index with key/metadata lookup tables
/// and optional on-disk persistence.
///
/// `add_batch` acquires the write locks in the order `index`, `key_to_id`,
/// `id_to_key`, `key_to_metadata`, `embeddings`.
pub struct VectorStore {
    /// Nearest-neighbour index over the stored embeddings.
    index: Arc<RwLock<Hnsw<Euclidian, Vec<f32>, Pcg64, 16, 32>>>,
    /// Maps an id returned by the index to the key it was inserted under.
    id_to_key: Arc<RwLock<HashMap<usize, String>>>,
    /// Maps a key to the id of its vector in the index.
    key_to_id: Arc<RwLock<HashMap<String, usize>>>,
    /// Metadata supplied for each key.
    key_to_metadata: Arc<RwLock<HashMap<String, serde_json::Value>>>,
    /// Directory holding the persisted vectors; `None` disables persistence.
    storage_path: Option<PathBuf>,
    /// Text embedding model used to turn content and queries into vectors.
    model: TextEmbedding,
    /// Length every embedding is required to have.
    dimensions: usize,
    /// Every entry inserted so far, in insertion order. `remove` does not
    /// touch this list; `compact` prunes it.
    embeddings: Arc<RwLock<Vec<VectorEntry>>>,
    /// Number of entries added since the counter was last reduced by a save.
    dirty_count: Arc<AtomicUsize>,
    /// `add_batch` triggers a save once `dirty_count` reaches this value.
    auto_save_threshold: usize,
}
72
/// A single hit returned by `VectorStore::search`.
#[derive(Debug, Serialize, Deserialize)]
pub struct SearchResult {
    /// Key the matching entry was stored under.
    pub key: String,
    /// Similarity computed as `1 / (1 + distance)`; larger means closer.
    pub score: f32,
    /// Metadata stored with the entry, or JSON `null` when none is recorded.
    pub metadata: serde_json::Value,
    /// The string at `metadata["uri"]` when present, otherwise the key.
    pub uri: String,
}
83
84impl VectorStore {
85 pub fn new(namespace: &str) -> Result<Self> {
87 let storage_path = std::env::var("GRAPH_STORAGE_PATH")
89 .ok()
90 .map(|p| PathBuf::from(p).join(namespace));
91
92 let dimensions = std::env::var("VECTOR_DIMENSIONS")
94 .ok()
95 .and_then(|s| s.parse().ok())
96 .unwrap_or(DEFAULT_DIMENSIONS);
97
98 let model_opts =
101 InitOptions::new(EmbeddingModel::BGESmallENV15).with_show_download_progress(true);
102
103 let model_opts = if let Ok(cache_path) = std::env::var("FASTEMBED_CACHE_PATH") {
105 model_opts.with_cache_dir(PathBuf::from(cache_path))
106 } else {
107 model_opts
108 };
109
110 let model = TextEmbedding::try_new(model_opts)?;
111
112 let mut index = Hnsw::new(Euclidian);
114 let mut id_to_key = HashMap::new();
115 let mut key_to_id = HashMap::new();
116 let mut key_to_metadata = HashMap::new();
117 let mut embeddings = Vec::new();
118
119 if let Some(ref path) = storage_path {
121 let vectors_bin = path.join("vectors.bin");
122 let vectors_json = path.join("vectors.json");
123
124 let loaded_data = if vectors_bin.exists() {
125 load_bincode::<VectorData>(&vectors_bin).ok()
126 } else if vectors_json.exists() {
127 let content = std::fs::read_to_string(&vectors_json).ok();
129 if let Some(content) = content {
130 if let Ok(data) = serde_json::from_str::<VectorData>(&content) {
132 Some(data)
133 } else {
134 #[derive(Serialize, Deserialize)]
136 struct OldVectorData {
137 entries: Vec<OldVectorEntry>,
138 }
139 #[derive(Serialize, Deserialize)]
140 struct OldVectorEntry {
141 uri: String,
142 embedding: Vec<f32>,
143 }
144
145 if let Ok(old_data) = serde_json::from_str::<OldVectorData>(&content) {
146 let entries = old_data
147 .entries
148 .into_iter()
149 .map(|old| VectorEntry {
150 key: old.uri.clone(),
151 embedding: old.embedding,
152 metadata: serde_json::json!({ "uri": old.uri }),
153 })
154 .collect();
155 Some(VectorData { entries })
156 } else {
157 None
158 }
159 }
160 } else {
161 None
162 }
163 } else {
164 None
165 };
166
167 if let Some(data) = loaded_data {
168 let mut searcher = hnsw::Searcher::default();
169 for entry in data.entries {
170 if entry.embedding.len() == dimensions {
171 let id = index.insert(entry.embedding.clone(), &mut searcher);
172 id_to_key.insert(id, entry.key.clone());
173 key_to_id.insert(entry.key.clone(), id);
174 key_to_metadata.insert(entry.key.clone(), entry.metadata.clone());
175 embeddings.push(entry);
176 }
177 }
178 eprintln!(
179 "Loaded {} vectors from disk (dim={})",
180 embeddings.len(),
181 dimensions
182 );
183 }
184 }
185
186 Ok(Self {
187 index: Arc::new(RwLock::new(index)),
188 id_to_key: Arc::new(RwLock::new(id_to_key)),
189 key_to_id: Arc::new(RwLock::new(key_to_id)),
190 key_to_metadata: Arc::new(RwLock::new(key_to_metadata)),
191 storage_path,
192 model,
193 dimensions,
194 embeddings: Arc::new(RwLock::new(embeddings)),
195 dirty_count: Arc::new(AtomicUsize::new(0)),
196 auto_save_threshold: DEFAULT_AUTO_SAVE_THRESHOLD,
197 })
198 }
199
200 fn save_vectors(&self) -> Result<()> {
202 if let Some(ref path) = self.storage_path {
203 std::fs::create_dir_all(path)?;
204
205 let (entries, current_dirty) = {
206 let guard = self.embeddings.read().unwrap();
207 (guard.clone(), self.dirty_count.load(Ordering::Relaxed))
208 };
209
210 let data = VectorData { entries };
211 save_bincode(&path.join("vectors.bin"), &data)?;
212
213 if current_dirty > 0 {
214 let _ = self.dirty_count.fetch_sub(current_dirty, Ordering::Relaxed);
215 }
216 }
217 Ok(())
218 }
219
220 pub fn flush(&self) -> Result<()> {
222 self.save_vectors()
223 }
224
225 pub async fn embed(&self, text: &str) -> Result<Vec<f32>> {
227 let embeddings = self.embed_batch(vec![text.to_string()]).await?;
228 Ok(embeddings[0].clone())
229 }
230
231 pub async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
233 if texts.is_empty() {
234 return Ok(Vec::new());
235 }
236
237 let embeddings = self.model.embed(texts, None)?;
241
242 let mut results = Vec::new();
243 for item in embeddings {
244 if item.len() != self.dimensions {
245 anyhow::bail!(
246 "Expected {} dimensions, got {}",
247 self.dimensions,
248 item.len()
249 );
250 }
251 results.push(item);
252 }
253
254 Ok(results)
255 }
256
257 pub async fn add(
259 &self,
260 key: &str,
261 content: &str,
262 metadata: serde_json::Value,
263 ) -> Result<usize> {
264 let results = self
265 .add_batch(vec![(key.to_string(), content.to_string(), metadata)])
266 .await?;
267 Ok(results[0])
268 }
269
270 pub async fn add_batch(
272 &self,
273 items: Vec<(String, String, serde_json::Value)>,
274 ) -> Result<Vec<usize>> {
275 let mut new_items = Vec::new();
277 let mut result_ids = vec![0; items.len()];
278 let mut new_indices = Vec::new(); {
281 let key_map = self.key_to_id.read().unwrap();
282 for (i, (key, content, _)) in items.iter().enumerate() {
283 if let Some(&id) = key_map.get(key) {
284 result_ids[i] = id;
285 } else {
286 new_items.push(content.clone());
287 new_indices.push(i);
288 }
289 }
290 }
291
292 if new_items.is_empty() {
293 return Ok(result_ids);
294 }
295
296 let embeddings = self.embed_batch(new_items).await?;
298
299 let mut ids_to_add = Vec::new();
300 let mut searcher = hnsw::Searcher::default();
301
302 {
304 let mut index = self.index.write().unwrap();
305 let mut key_map = self.key_to_id.write().unwrap();
306 let mut id_map = self.id_to_key.write().unwrap();
307 let mut metadata_map = self.key_to_metadata.write().unwrap();
308 let mut embs = self.embeddings.write().unwrap();
309
310 for (i, embedding) in embeddings.into_iter().enumerate() {
311 let original_idx = new_indices[i];
312 let (key, _, metadata) = &items[original_idx];
313
314 if let Some(&id) = key_map.get(key) {
316 result_ids[original_idx] = id;
317 continue;
318 }
319
320 let id = index.insert(embedding.clone(), &mut searcher);
321 key_map.insert(key.clone(), id);
322 id_map.insert(id, key.clone());
323 metadata_map.insert(key.clone(), metadata.clone());
324
325 embs.push(VectorEntry {
326 key: key.clone(),
327 embedding,
328 metadata: metadata.clone(),
329 });
330
331 result_ids[original_idx] = id;
332 ids_to_add.push(id);
333 }
334 }
335
336 if !ids_to_add.is_empty() {
337 let count = self
338 .dirty_count
339 .fetch_add(ids_to_add.len(), Ordering::Relaxed);
340 if count + ids_to_add.len() >= self.auto_save_threshold {
341 let _ = self.save_vectors();
342 }
343 }
344
345 Ok(result_ids)
346 }
347
348 pub async fn search(&self, query: &str, k: usize) -> Result<Vec<SearchResult>> {
350 let query_embedding = self.embed(query).await?;
352
353 let mut searcher = hnsw::Searcher::default();
355
356 let index = self.index.read().unwrap();
357 let len = index.len();
358 if len == 0 {
359 return Ok(Vec::new());
360 }
361
362 let k = k.min(len);
363 let ef = k.max(50); let mut neighbors = vec![
366 space::Neighbor {
367 index: 0,
368 distance: u32::MAX
369 };
370 k
371 ];
372
373 let found_neighbors = index.nearest(&query_embedding, ef, &mut searcher, &mut neighbors);
374
375 let id_map = self.id_to_key.read().unwrap();
377 let metadata_map = self.key_to_metadata.read().unwrap();
378
379 let results: Vec<SearchResult> = found_neighbors
380 .iter()
381 .filter_map(|neighbor| {
382 id_map.get(&neighbor.index).map(|key| {
383 let score_f32 = (neighbor.distance as f32) / 1_000_000.0;
384
385 let metadata = metadata_map
386 .get(key)
387 .cloned()
388 .unwrap_or(serde_json::Value::Null);
389
390 let uri = metadata
391 .get("uri")
392 .and_then(|v| v.as_str())
393 .unwrap_or(key)
394 .to_string();
395
396 SearchResult {
397 key: key.clone(),
398 score: 1.0 / (1.0 + score_f32),
399 metadata,
400 uri,
401 }
402 })
403 })
404 .collect();
405
406 Ok(results)
407 }
408
409 pub fn get_key(&self, id: usize) -> Option<String> {
410 self.id_to_key.read().unwrap().get(&id).cloned()
411 }
412
413 pub fn get_id(&self, key: &str) -> Option<usize> {
414 self.key_to_id.read().unwrap().get(key).copied()
415 }
416
417 pub fn len(&self) -> usize {
418 self.key_to_id.read().unwrap().len()
419 }
420
421 pub fn is_empty(&self) -> bool {
422 self.len() == 0
423 }
424
425 pub fn compact(&self) -> Result<usize> {
427 let embeddings = self.embeddings.read().unwrap();
428 let current_keys: std::collections::HashSet<_> =
429 self.key_to_id.read().unwrap().keys().cloned().collect();
430
431 if current_keys.is_empty() && !embeddings.is_empty() {
432 return Ok(0);
433 }
434
435 let active_entries: Vec<_> = embeddings
437 .iter()
438 .filter(|e| current_keys.contains(&e.key))
439 .cloned()
440 .collect();
441
442 let removed = embeddings.len() - active_entries.len();
443
444 if removed == 0 {
445 return Ok(0);
446 }
447
448 let mut new_index = hnsw::Hnsw::new(Euclidian);
450 let mut new_id_to_key = std::collections::HashMap::new();
451 let mut new_key_to_id = std::collections::HashMap::new();
452 let mut new_key_to_metadata = std::collections::HashMap::new();
453 let mut searcher = hnsw::Searcher::default();
454
455 for entry in &active_entries {
456 if entry.embedding.len() == self.dimensions {
457 let id = new_index.insert(entry.embedding.clone(), &mut searcher);
458 new_id_to_key.insert(id, entry.key.clone());
459 new_key_to_id.insert(entry.key.clone(), id);
460 new_key_to_metadata.insert(entry.key.clone(), entry.metadata.clone());
461 }
462 }
463
464 *self.index.write().unwrap() = new_index;
466 *self.id_to_key.write().unwrap() = new_id_to_key;
467 *self.key_to_id.write().unwrap() = new_key_to_id;
468 *self.key_to_metadata.write().unwrap() = new_key_to_metadata;
469
470 drop(embeddings);
472 *self.embeddings.write().unwrap() = active_entries;
473
474 let _ = self.save_vectors();
475
476 Ok(removed)
477 }
478
479 pub fn remove(&self, key: &str) -> bool {
481 let mut key_map = self.key_to_id.write().unwrap();
482 let mut id_map = self.id_to_key.write().unwrap();
483 let mut metadata_map = self.key_to_metadata.write().unwrap();
484
485 if let Some(id) = key_map.remove(key) {
486 id_map.remove(&id);
487 metadata_map.remove(key);
488 true
490 } else {
491 false
492 }
493 }
494
495 pub fn stats(&self) -> (usize, usize, usize) {
497 let embeddings_count = self.embeddings.read().unwrap().len();
498 let active_count = self.key_to_id.read().unwrap().len();
499 let stale_count = embeddings_count.saturating_sub(active_count);
500 (active_count, stale_count, embeddings_count)
501 }
502}