1use anyhow::Result;
2use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
3use hnsw::Hnsw;
4use rand_pcg::Pcg64;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::path::PathBuf;
8use std::sync::atomic::{AtomicUsize, Ordering};
9use std::sync::{Arc, RwLock};
10
/// Embedding width used when `VECTOR_DIMENSIONS` is unset or cannot be parsed.
const DEFAULT_DIMENSIONS: usize = 384;
/// Number of unsaved insertions at which `add_batch` triggers a save to disk.
const DEFAULT_AUTO_SAVE_THRESHOLD: usize = 100;
13
/// Zero-sized marker type that supplies the distance function of the HNSW
/// index (see its `space::Metric` implementation).
#[derive(Default, Clone)]
pub struct Euclidian;
17
18impl space::Metric<Vec<f32>> for Euclidian {
19 type Unit = u32;
20 fn distance(&self, a: &Vec<f32>, b: &Vec<f32>) -> u32 {
21 let len = a.len().min(b.len());
22 let mut dist_sq = 0.0;
23 for i in 0..len {
24 let diff = a[i] - b[i];
25 dist_sq += diff * diff;
26 }
27 (dist_sq.sqrt() * 1_000_000.0) as u32
29 }
30}
31
/// Root object of the persisted `vectors.json` document.
#[derive(Serialize, Deserialize, Default)]
struct VectorData {
    /// All persisted vectors, in the order they were stored.
    entries: Vec<VectorEntry>,
}
37
/// One persisted vector together with its key and metadata.
#[derive(Serialize, Deserialize, Clone)]
struct VectorEntry {
    /// Caller-supplied identifier of the entry.
    key: String,
    /// The embedding vector stored for `key`.
    embedding: Vec<f32>,
    /// Metadata serialized as a JSON string; deserializes to an empty string
    /// when the field is absent from the file.
    #[serde(default)]
    metadata_json: String,
}
47
/// Embedding-backed nearest-neighbour store with optional JSON persistence.
pub struct VectorStore {
    /// HNSW index over the stored embeddings, using the `Euclidian` metric.
    index: Arc<RwLock<Hnsw<Euclidian, Vec<f32>, Pcg64, 16, 32>>>,
    /// Maps the id returned by the index on insertion to the entry's key.
    id_to_key: Arc<RwLock<HashMap<usize, String>>>,
    /// Maps a key to its index id; its size is what `len()` reports.
    key_to_id: Arc<RwLock<HashMap<String, usize>>>,
    /// Metadata value associated with each key.
    key_to_metadata: Arc<RwLock<HashMap<String, serde_json::Value>>>,
    /// Directory that holds `vectors.json`; `None` when `GRAPH_STORAGE_PATH`
    /// was not set, in which case nothing is written to disk.
    storage_path: Option<PathBuf>,
    /// Text embedding model used to turn text into vectors.
    model: TextEmbedding,
    /// Expected embedding length; entries of any other length are not indexed
    /// when loading from disk or compacting.
    dimensions: usize,
    /// The entries that are serialized to disk on save.
    embeddings: Arc<RwLock<Vec<VectorEntry>>>,
    /// Count of insertions made since the last save.
    dirty_count: Arc<AtomicUsize>,
    /// Value of `dirty_count` at which `add_batch` saves automatically.
    auto_save_threshold: usize,
}
71
/// A single hit returned by `VectorStore::search`.
#[derive(Debug, Serialize, Deserialize)]
pub struct SearchResult {
    /// Key of the matching entry.
    pub key: String,
    /// Similarity computed as `1 / (1 + distance)`; higher means closer.
    pub score: f32,
    /// Metadata stored for the entry, or JSON `null` when none is recorded.
    pub metadata: serde_json::Value,
    /// The string value of the metadata's `"uri"` field, falling back to the key.
    pub uri: String,
}
82
83impl VectorStore {
84 pub fn new(namespace: &str) -> Result<Self> {
86 let storage_path = std::env::var("GRAPH_STORAGE_PATH")
88 .ok()
89 .map(|p| PathBuf::from(p).join(namespace));
90
91 let dimensions = std::env::var("VECTOR_DIMENSIONS")
93 .ok()
94 .and_then(|s| s.parse().ok())
95 .unwrap_or(DEFAULT_DIMENSIONS);
96
97 let mut model_opts =
99 InitOptions::new(EmbeddingModel::BGESmallENV15).with_show_download_progress(true);
100
101 if let Ok(cache_path) = std::env::var("FASTEMBED_CACHE_PATH") {
102 model_opts = model_opts.with_cache_dir(PathBuf::from(cache_path));
103 }
104
105 let model = TextEmbedding::try_new(model_opts)?;
106
107 let mut index = Hnsw::new(Euclidian);
109 let mut id_to_key = HashMap::new();
110 let mut key_to_id = HashMap::new();
111 let mut key_to_metadata = HashMap::new();
112 let mut embeddings = Vec::new();
113
114 if let Some(ref path) = storage_path {
116 let vectors_json = path.join("vectors.json");
117
118 let loaded_data = if vectors_json.exists() {
119 match std::fs::read_to_string(&vectors_json) {
120 Ok(content) => match serde_json::from_str::<VectorData>(&content) {
121 Ok(data) => Some(data),
122 Err(e) => {
123 eprintln!("ERROR: Failed to parse vectors: {}", e);
124 None
125 }
126 },
127 Err(_) => None,
128 }
129 } else {
130 None
131 };
132
133 if let Some(data) = loaded_data {
134 let mut searcher = hnsw::Searcher::default();
135 for entry in data.entries {
136 if entry.embedding.len() == dimensions {
137 let id = index.insert(entry.embedding.clone(), &mut searcher);
138 id_to_key.insert(id, entry.key.clone());
139 key_to_id.insert(entry.key.clone(), id);
140
141 let metadata = serde_json::from_str(&entry.metadata_json).unwrap_or(serde_json::Value::Null);
142 key_to_metadata.insert(entry.key.clone(), metadata);
143 embeddings.push(entry);
144 }
145 }
146 eprintln!("Loaded {} vectors from disk", embeddings.len());
147 }
148 }
149
150 Ok(Self {
151 index: Arc::new(RwLock::new(index)),
152 id_to_key: Arc::new(RwLock::new(id_to_key)),
153 key_to_id: Arc::new(RwLock::new(key_to_id)),
154 key_to_metadata: Arc::new(RwLock::new(key_to_metadata)),
155 storage_path,
156 model,
157 dimensions,
158 embeddings: Arc::new(RwLock::new(embeddings)),
159 dirty_count: Arc::new(AtomicUsize::new(0)),
160 auto_save_threshold: DEFAULT_AUTO_SAVE_THRESHOLD,
161 })
162 }
163
164 fn save_vectors(&self) -> Result<()> {
166 if let Some(ref path) = self.storage_path {
167 std::fs::create_dir_all(path)?;
168
169 let (entries, current_dirty) = {
170 let guard = self.embeddings.read().unwrap();
171 (guard.clone(), self.dirty_count.load(Ordering::Relaxed))
172 };
173
174 let data = VectorData { entries };
175 let json = serde_json::to_string_pretty(&data)?;
176 std::fs::write(path.join("vectors.json"), json)?;
177
178 if current_dirty > 0 {
179 let _ = self.dirty_count.fetch_sub(current_dirty, Ordering::Relaxed);
180 }
181 }
182 Ok(())
183 }
184
185 pub fn flush(&self) -> Result<()> {
186 self.save_vectors()
187 }
188
189 pub async fn embed(&self, text: &str) -> Result<Vec<f32>> {
190 let embeddings = self.embed_batch(vec![text.to_string()]).await?;
191 Ok(embeddings[0].clone())
192 }
193
194 pub async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
195 if texts.is_empty() {
196 return Ok(Vec::new());
197 }
198 let embeddings = self.model.embed(texts, None)?;
199 Ok(embeddings)
200 }
201
202 pub async fn add(
203 &self,
204 key: &str,
205 content: &str,
206 metadata: serde_json::Value,
207 ) -> Result<usize> {
208 let results = self
209 .add_batch(vec![(key.to_string(), content.to_string(), metadata)])
210 .await?;
211 Ok(results[0])
212 }
213
214 pub async fn add_batch(
215 &self,
216 items: Vec<(String, String, serde_json::Value)>,
217 ) -> Result<Vec<usize>> {
218 let mut new_items = Vec::new();
219 let mut result_ids = vec![0; items.len()];
220 let mut new_indices = Vec::new();
221
222 {
223 let key_map = self.key_to_id.read().unwrap();
224 for (i, (key, content, _)) in items.iter().enumerate() {
225 if let Some(&id) = key_map.get(key) {
226 result_ids[i] = id;
227 } else {
228 new_items.push(content.clone());
229 new_indices.push(i);
230 }
231 }
232 }
233
234 if new_items.is_empty() {
235 return Ok(result_ids);
236 }
237
238 let embeddings = self.embed_batch(new_items).await?;
239 let mut ids_to_add = Vec::new();
240 let mut searcher = hnsw::Searcher::default();
241
242 {
243 let mut index = self.index.write().unwrap();
244 let mut key_map = self.key_to_id.write().unwrap();
245 let mut id_map = self.id_to_key.write().unwrap();
246 let mut metadata_map = self.key_to_metadata.write().unwrap();
247 let mut embs = self.embeddings.write().unwrap();
248
249 for (i, embedding) in embeddings.into_iter().enumerate() {
250 let original_idx = new_indices[i];
251 let (key, _, metadata) = &items[original_idx];
252
253 if let Some(&id) = key_map.get(key) {
254 result_ids[original_idx] = id;
255 continue;
256 }
257
258 let id = index.insert(embedding.clone(), &mut searcher);
259 key_map.insert(key.clone(), id);
260 id_map.insert(id, key.clone());
261 metadata_map.insert(key.clone(), metadata.clone());
262
263 embs.push(VectorEntry {
264 key: key.clone(),
265 embedding,
266 metadata_json: serde_json::to_string(metadata).unwrap_or_default(),
267 });
268
269 result_ids[original_idx] = id;
270 ids_to_add.push(id);
271 }
272 }
273
274 if !ids_to_add.is_empty() {
275 let count = self
276 .dirty_count
277 .fetch_add(ids_to_add.len(), Ordering::Relaxed);
278 if count + ids_to_add.len() >= self.auto_save_threshold {
279 let _ = self.save_vectors();
280 }
281 }
282
283 Ok(result_ids)
284 }
285
286 pub async fn search(&self, query: &str, k: usize) -> Result<Vec<SearchResult>> {
287 let query_embedding = self.embed(query).await?;
288 let mut searcher = hnsw::Searcher::default();
289
290 let index = self.index.read().unwrap();
291 let len = index.len();
292 if len == 0 {
293 return Ok(Vec::new());
294 }
295
296 let k = k.min(len);
297 let ef = k.max(50);
298
299 let mut neighbors = vec![
300 space::Neighbor {
301 index: 0,
302 distance: u32::MAX
303 };
304 k
305 ];
306
307 let found_neighbors = index.nearest(&query_embedding, ef, &mut searcher, &mut neighbors);
308
309 let id_map = self.id_to_key.read().unwrap();
310 let metadata_map = self.key_to_metadata.read().unwrap();
311
312 let results: Vec<SearchResult> = found_neighbors
313 .iter()
314 .filter_map(|neighbor| {
315 id_map.get(&neighbor.index).map(|key| {
316 let score_f32 = (neighbor.distance as f32) / 1_000_000.0;
317 let metadata = metadata_map
318 .get(key)
319 .cloned()
320 .unwrap_or(serde_json::Value::Null);
321 let uri = metadata
322 .get("uri")
323 .and_then(|v| v.as_str())
324 .unwrap_or(key)
325 .to_string();
326
327 SearchResult {
328 key: key.clone(),
329 score: 1.0 / (1.0 + score_f32),
330 metadata,
331 uri,
332 }
333 })
334 })
335 .collect();
336
337 Ok(results)
338 }
339
340 pub fn get_key(&self, id: usize) -> Option<String> {
341 self.id_to_key.read().unwrap().get(&id).cloned()
342 }
343
344 pub fn get_id(&self, key: &str) -> Option<usize> {
345 self.key_to_id.read().unwrap().get(key).copied()
346 }
347
348 pub fn len(&self) -> usize {
349 self.key_to_id.read().unwrap().len()
350 }
351
352 pub fn is_empty(&self) -> bool {
353 self.len() == 0
354 }
355
356 pub fn compact(&self) -> Result<usize> {
357 let embeddings = self.embeddings.read().unwrap();
358 let current_keys: std::collections::HashSet<_> =
359 self.key_to_id.read().unwrap().keys().cloned().collect();
360
361 if current_keys.is_empty() && !embeddings.is_empty() {
362 return Ok(0);
363 }
364
365 let active_entries: Vec<_> = embeddings
366 .iter()
367 .filter(|e| current_keys.contains(&e.key))
368 .cloned()
369 .collect();
370
371 let removed = embeddings.len() - active_entries.len();
372 if removed == 0 {
373 return Ok(0);
374 }
375
376 let mut new_index = hnsw::Hnsw::new(Euclidian);
377 let mut new_id_to_key = std::collections::HashMap::new();
378 let mut new_key_to_id = std::collections::HashMap::new();
379 let mut new_key_to_metadata = std::collections::HashMap::new();
380 let mut searcher = hnsw::Searcher::default();
381
382 for entry in &active_entries {
383 if entry.embedding.len() == self.dimensions {
384 let id = new_index.insert(entry.embedding.clone(), &mut searcher);
385 new_id_to_key.insert(id, entry.key.clone());
386 new_key_to_id.insert(entry.key.clone(), id);
387 let metadata = serde_json::from_str(&entry.metadata_json).unwrap_or(serde_json::Value::Null);
388 new_key_to_metadata.insert(entry.key.clone(), metadata);
389 }
390 }
391
392 *self.index.write().unwrap() = new_index;
393 *self.id_to_key.write().unwrap() = new_id_to_key;
394 *self.key_to_id.write().unwrap() = new_key_to_id;
395 *self.key_to_metadata.write().unwrap() = new_key_to_metadata;
396
397 drop(embeddings);
398 *self.embeddings.write().unwrap() = active_entries;
399 let _ = self.save_vectors();
400 Ok(removed)
401 }
402
403 pub fn remove(&self, key: &str) -> bool {
404 let mut key_map = self.key_to_id.write().unwrap();
405 let mut id_map = self.id_to_key.write().unwrap();
406 let mut metadata_map = self.key_to_metadata.write().unwrap();
407
408 if let Some(id) = key_map.remove(key) {
409 id_map.remove(&id);
410 metadata_map.remove(key);
411 true
412 } else {
413 false
414 }
415 }
416
417 pub fn stats(&self) -> (usize, usize, usize) {
418 let embeddings_count = self.embeddings.read().unwrap().len();
419 let active_count = self.key_to_id.read().unwrap().len();
420 let stale_count = embeddings_count.saturating_sub(active_count);
421 (active_count, stale_count, embeddings_count)
422 }
423}