// synapse_core/vector_store.rs

use crate::persistence::{load_bincode, save_bincode};
use anyhow::Result;
use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
use hnsw::Hnsw;
use rand_pcg::Pcg64;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, RwLock};

/// Embedding width assumed when `VECTOR_DIMENSIONS` is not set in the environment.
const DEFAULT_DIMENSIONS: usize = 384;
/// Number of unsaved changes after which the store persists itself automatically.
const DEFAULT_AUTO_SAVE_THRESHOLD: usize = 100;

/// Euclidean distance metric for HNSW.
///
/// The type is a stateless marker; the spelling of the name is kept because it is
/// part of the public interface.
#[derive(Debug, Default, Clone, Copy)]
pub struct Euclidian;

19impl space::Metric<Vec<f32>> for Euclidian {
20    type Unit = u32;
21    fn distance(&self, a: &Vec<f32>, b: &Vec<f32>) -> u32 {
22        let len = a.len().min(b.len());
23        let mut dist_sq = 0.0;
24        for i in 0..len {
25            let diff = a[i] - b[i];
26            dist_sq += diff * diff;
27        }
28        // Use fixed-point arithmetic to avoid bitwise comparison issues
29        (dist_sq.sqrt() * 1_000_000.0) as u32
30    }
31}
32
33/// Persisted vector data
34#[derive(Serialize, Deserialize, Default)]
35struct VectorData {
36    entries: Vec<VectorEntry>,
37}
38
39#[derive(Serialize, Deserialize, Clone)]
40struct VectorEntry {
41    /// Unique identifier for this vector (could be URI or Hash)
42    key: String,
43    embedding: Vec<f32>,
44    /// Optional metadata associated with the vector
45    #[serde(default)]
46    metadata: serde_json::Value,
47}
48
49/// Vector store using Local FastEmbed for embeddings
50pub struct VectorStore {
51    /// HNSW index for fast approximate nearest neighbor search
52    index: Arc<RwLock<Hnsw<Euclidian, Vec<f32>, Pcg64, 16, 32>>>,
53    /// Mapping from node ID (internal) to Key
54    id_to_key: Arc<RwLock<HashMap<usize, String>>>,
55    /// Mapping from Key to node ID (internal)
56    key_to_id: Arc<RwLock<HashMap<String, usize>>>,
57    /// Mapping from Key to Metadata (for fast retrieval)
58    key_to_metadata: Arc<RwLock<HashMap<String, serde_json::Value>>>,
59    /// Storage path for persistence
60    storage_path: Option<PathBuf>,
61    /// Local embedding model
62    model: TextEmbedding,
63    /// Vector dimensions
64    dimensions: usize,
65    /// Stored embeddings for persistence
66    embeddings: Arc<RwLock<Vec<VectorEntry>>>,
67    /// Number of unsaved changes
68    dirty_count: Arc<AtomicUsize>,
69    /// Threshold for auto-save
70    auto_save_threshold: usize,
71}
72
73#[derive(Debug, Serialize, Deserialize)]
74pub struct SearchResult {
75    /// The unique key
76    pub key: String,
77    pub score: f32,
78    /// Metadata (including original URI if applicable)
79    pub metadata: serde_json::Value,
80    // Helper to access URI from metadata if it exists, for backward compatibility
81    pub uri: String,
82}
83
84impl VectorStore {
85    /// Create a new vector store for a namespace
86    pub fn new(namespace: &str) -> Result<Self> {
87        // Try to get storage path from environment
88        let storage_path = std::env::var("GRAPH_STORAGE_PATH")
89            .ok()
90            .map(|p| PathBuf::from(p).join(namespace));
91
92        // Get dimensions from env or default
93        let dimensions = std::env::var("VECTOR_DIMENSIONS")
94            .ok()
95            .and_then(|s| s.parse().ok())
96            .unwrap_or(DEFAULT_DIMENSIONS);
97
98        // Initialize FastEmbed model
99        // We use BGESmallENV15 as it is small and effective (384 dims)
100        let model_opts = InitOptions::new(EmbeddingModel::BGESmallENV15)
101            .with_show_download_progress(true);
102        
103        // Use custom cache dir if specified
104        let model_opts = if let Ok(cache_path) = std::env::var("FASTEMBED_CACHE_PATH") {
105             model_opts.with_cache_dir(PathBuf::from(cache_path))
106        } else {
107             model_opts
108        };
109
110        let model = TextEmbedding::try_new(model_opts)?;
111
112        // Create HNSW index with Euclidian metric
113        let mut index = Hnsw::new(Euclidian);
114        let mut id_to_key = HashMap::new();
115        let mut key_to_id = HashMap::new();
116        let mut key_to_metadata = HashMap::new();
117        let mut embeddings = Vec::new();
118
119        // Try to load persisted vectors
120        if let Some(ref path) = storage_path {
121            let vectors_bin = path.join("vectors.bin");
122            let vectors_json = path.join("vectors.json");
123
124            let loaded_data = if vectors_bin.exists() {
125                load_bincode::<VectorData>(&vectors_bin).ok()
126            } else if vectors_json.exists() {
127                // Fallback / Migration from JSON
128                let content = std::fs::read_to_string(&vectors_json).ok();
129                if let Some(content) = content {
130                    // Try new format first
131                    if let Ok(data) = serde_json::from_str::<VectorData>(&content) {
132                        Some(data)
133                    } else {
134                        // Fallback: Try loading old format
135                        #[derive(Serialize, Deserialize)]
136                        struct OldVectorData {
137                            entries: Vec<OldVectorEntry>,
138                        }
139                        #[derive(Serialize, Deserialize)]
140                        struct OldVectorEntry {
141                            uri: String,
142                            embedding: Vec<f32>,
143                        }
144
145                        if let Ok(old_data) = serde_json::from_str::<OldVectorData>(&content) {
146                            let entries = old_data
147                                .entries
148                                .into_iter()
149                                .map(|old| VectorEntry {
150                                    key: old.uri.clone(),
151                                    embedding: old.embedding,
152                                    metadata: serde_json::json!({ "uri": old.uri }),
153                                })
154                                .collect();
155                            Some(VectorData { entries })
156                        } else {
157                            None
158                        }
159                    }
160                } else {
161                    None
162                }
163            } else {
164                None
165            };
166
167            if let Some(data) = loaded_data {
168                let mut searcher = hnsw::Searcher::default();
169                for entry in data.entries {
170                    if entry.embedding.len() == dimensions {
171                        let id = index.insert(entry.embedding.clone(), &mut searcher);
172                        id_to_key.insert(id, entry.key.clone());
173                        key_to_id.insert(entry.key.clone(), id);
174                        key_to_metadata.insert(entry.key.clone(), entry.metadata.clone());
175                        embeddings.push(entry);
176                    }
177                }
178                eprintln!(
179                    "Loaded {} vectors from disk (dim={})",
180                    embeddings.len(),
181                    dimensions
182                );
183            }
184        }
185
186        Ok(Self {
187            index: Arc::new(RwLock::new(index)),
188            id_to_key: Arc::new(RwLock::new(id_to_key)),
189            key_to_id: Arc::new(RwLock::new(key_to_id)),
190            key_to_metadata: Arc::new(RwLock::new(key_to_metadata)),
191            storage_path,
192            model,
193            dimensions,
194            embeddings: Arc::new(RwLock::new(embeddings)),
195            dirty_count: Arc::new(AtomicUsize::new(0)),
196            auto_save_threshold: DEFAULT_AUTO_SAVE_THRESHOLD,
197        })
198    }
199
200    /// Save vectors to disk
201    fn save_vectors(&self) -> Result<()> {
202        if let Some(ref path) = self.storage_path {
203            std::fs::create_dir_all(path)?;
204
205            let (entries, current_dirty) = {
206                let guard = self.embeddings.read().unwrap();
207                (guard.clone(), self.dirty_count.load(Ordering::Relaxed))
208            };
209
210            let data = VectorData { entries };
211            save_bincode(&path.join("vectors.bin"), &data)?;
212
213            if current_dirty > 0 {
214                let _ = self.dirty_count.fetch_sub(current_dirty, Ordering::Relaxed);
215            }
216        }
217        Ok(())
218    }
219
220    /// Force save to disk
221    pub fn flush(&self) -> Result<()> {
222        self.save_vectors()
223    }
224
225    /// Generate embedding for a text using local model
226    pub async fn embed(&self, text: &str) -> Result<Vec<f32>> {
227        let embeddings = self.embed_batch(vec![text.to_string()]).await?;
228        Ok(embeddings[0].clone())
229    }
230
231    /// Generate embeddings for multiple texts using local model
232    pub async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
233        if texts.is_empty() {
234            return Ok(Vec::new());
235        }
236
237        // Generate embeddings using local model (CPU)
238        // Note: model.embed is not async, but it's fast enough for small batches
239        // We could wrap it in spawn_blocking if needed, but for now we'll keep it simple
240        let embeddings = self.model.embed(texts, None)?;
241        
242        let mut results = Vec::new();
243        for item in embeddings {
244            if item.len() != self.dimensions {
245                anyhow::bail!("Expected {} dimensions, got {}", self.dimensions, item.len());
246            }
247            results.push(item);
248        }
249
250        Ok(results)
251    }
252
253    /// Add a key with its text content to the index
254    pub async fn add(
255        &self,
256        key: &str,
257        content: &str,
258        metadata: serde_json::Value,
259    ) -> Result<usize> {
260        let results = self
261            .add_batch(vec![(key.to_string(), content.to_string(), metadata)])
262            .await?;
263        Ok(results[0])
264    }
265
266    /// Add multiple keys with their text content to the index
267    pub async fn add_batch(
268        &self,
269        items: Vec<(String, String, serde_json::Value)>,
270    ) -> Result<Vec<usize>> {
271        // Filter out existing keys
272        let mut new_items = Vec::new();
273        let mut result_ids = vec![0; items.len()];
274        let mut new_indices = Vec::new(); // Map index in new_items to index in items
275
276        {
277            let key_map = self.key_to_id.read().unwrap();
278            for (i, (key, content, _)) in items.iter().enumerate() {
279                if let Some(&id) = key_map.get(key) {
280                    result_ids[i] = id;
281                } else {
282                    new_items.push(content.clone());
283                    new_indices.push(i);
284                }
285            }
286        }
287
288        if new_items.is_empty() {
289            return Ok(result_ids);
290        }
291
292        // Generate embeddings via fastembed
293        let embeddings = self.embed_batch(new_items).await?;
294
295        let mut ids_to_add = Vec::new();
296        let mut searcher = hnsw::Searcher::default();
297
298        // Add to HNSW index and maps
299        {
300            let mut index = self.index.write().unwrap();
301            let mut key_map = self.key_to_id.write().unwrap();
302            let mut id_map = self.id_to_key.write().unwrap();
303            let mut metadata_map = self.key_to_metadata.write().unwrap();
304            let mut embs = self.embeddings.write().unwrap();
305
306            for (i, embedding) in embeddings.into_iter().enumerate() {
307                let original_idx = new_indices[i];
308                let (key, _, metadata) = &items[original_idx];
309
310                // Double check if inserted in race condition
311                if let Some(&id) = key_map.get(key) {
312                    result_ids[original_idx] = id;
313                    continue;
314                }
315
316                let id = index.insert(embedding.clone(), &mut searcher);
317                key_map.insert(key.clone(), id);
318                id_map.insert(id, key.clone());
319                metadata_map.insert(key.clone(), metadata.clone());
320
321                embs.push(VectorEntry {
322                    key: key.clone(),
323                    embedding,
324                    metadata: metadata.clone(),
325                });
326
327                result_ids[original_idx] = id;
328                ids_to_add.push(id);
329            }
330        }
331
332        if !ids_to_add.is_empty() {
333            let count = self
334                .dirty_count
335                .fetch_add(ids_to_add.len(), Ordering::Relaxed);
336            if count + ids_to_add.len() >= self.auto_save_threshold {
337                let _ = self.save_vectors();
338            }
339        }
340
341        Ok(result_ids)
342    }
343
344    /// Search for similar vectors
345    pub async fn search(&self, query: &str, k: usize) -> Result<Vec<SearchResult>> {
346        // Generate query embedding
347        let query_embedding = self.embed(query).await?;
348
349        // Search HNSW index
350        let mut searcher = hnsw::Searcher::default();
351
352        let index = self.index.read().unwrap();
353        let len = index.len();
354        if len == 0 {
355            return Ok(Vec::new());
356        }
357
358        let k = k.min(len);
359        let ef = k.max(50); // Ensure ef is at least k, but usually 50+
360
361        let mut neighbors = vec![
362            space::Neighbor {
363                index: 0,
364                distance: u32::MAX
365            };
366            k
367        ];
368
369        let found_neighbors = index.nearest(&query_embedding, ef, &mut searcher, &mut neighbors);
370
371        // Convert to results
372        let id_map = self.id_to_key.read().unwrap();
373        let metadata_map = self.key_to_metadata.read().unwrap();
374
375        let results: Vec<SearchResult> = found_neighbors
376            .iter()
377            .filter_map(|neighbor| {
378                id_map.get(&neighbor.index).map(|key| {
379                    let score_f32 = (neighbor.distance as f32) / 1_000_000.0;
380
381                    let metadata = metadata_map
382                        .get(key)
383                        .cloned()
384                        .unwrap_or(serde_json::Value::Null);
385
386                    let uri = metadata
387                        .get("uri")
388                        .and_then(|v| v.as_str())
389                        .unwrap_or(key)
390                        .to_string();
391
392                    SearchResult {
393                        key: key.clone(),
394                        score: 1.0 / (1.0 + score_f32),
395                        metadata,
396                        uri,
397                    }
398                })
399            })
400            .collect();
401
402        Ok(results)
403    }
404
405    pub fn get_key(&self, id: usize) -> Option<String> {
406        self.id_to_key.read().unwrap().get(&id).cloned()
407    }
408
409    pub fn get_id(&self, key: &str) -> Option<usize> {
410        self.key_to_id.read().unwrap().get(key).copied()
411    }
412
413    pub fn len(&self) -> usize {
414        self.key_to_id.read().unwrap().len()
415    }
416
417    pub fn is_empty(&self) -> bool {
418        self.len() == 0
419    }
420
421    /// Compaction: rebuild index from stored embeddings, removing stale entries
422    pub fn compact(&self) -> Result<usize> {
423        let embeddings = self.embeddings.read().unwrap();
424        let current_keys: std::collections::HashSet<_> =
425            self.key_to_id.read().unwrap().keys().cloned().collect();
426
427        if current_keys.is_empty() && !embeddings.is_empty() {
428            return Ok(0);
429        }
430
431        // Filter to only current Keys
432        let active_entries: Vec<_> = embeddings
433            .iter()
434            .filter(|e| current_keys.contains(&e.key))
435            .cloned()
436            .collect();
437
438        let removed = embeddings.len() - active_entries.len();
439
440        if removed == 0 {
441            return Ok(0);
442        }
443
444        // Rebuild index
445        let mut new_index = hnsw::Hnsw::new(Euclidian);
446        let mut new_id_to_key = std::collections::HashMap::new();
447        let mut new_key_to_id = std::collections::HashMap::new();
448        let mut new_key_to_metadata = std::collections::HashMap::new();
449        let mut searcher = hnsw::Searcher::default();
450
451        for entry in &active_entries {
452            if entry.embedding.len() == self.dimensions {
453                let id = new_index.insert(entry.embedding.clone(), &mut searcher);
454                new_id_to_key.insert(id, entry.key.clone());
455                new_key_to_id.insert(entry.key.clone(), id);
456                new_key_to_metadata.insert(entry.key.clone(), entry.metadata.clone());
457            }
458        }
459
460        // Swap in new index
461        *self.index.write().unwrap() = new_index;
462        *self.id_to_key.write().unwrap() = new_id_to_key;
463        *self.key_to_id.write().unwrap() = new_key_to_id;
464        *self.key_to_metadata.write().unwrap() = new_key_to_metadata;
465
466        // Update embeddings (drop takes write lock)
467        drop(embeddings);
468        *self.embeddings.write().unwrap() = active_entries;
469
470        let _ = self.save_vectors();
471
472        Ok(removed)
473    }
474
475    /// Remove a Key from the vector store
476    pub fn remove(&self, key: &str) -> bool {
477        let mut key_map = self.key_to_id.write().unwrap();
478        let mut id_map = self.id_to_key.write().unwrap();
479        let mut metadata_map = self.key_to_metadata.write().unwrap();
480
481        if let Some(id) = key_map.remove(key) {
482            id_map.remove(&id);
483            metadata_map.remove(key);
484            // Note: actual index entry remains until compaction
485            true
486        } else {
487            false
488        }
489    }
490
491    /// Get storage stats
492    pub fn stats(&self) -> (usize, usize, usize) {
493        let embeddings_count = self.embeddings.read().unwrap().len();
494        let active_count = self.key_to_id.read().unwrap().len();
495        let stale_count = embeddings_count.saturating_sub(active_count);
496        (active_count, stale_count, embeddings_count)
497    }
498}