Skip to main content

synapse_core/
vector_store.rs

1use anyhow::Result;
2use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
3use hnsw::Hnsw;
4use rand_pcg::Pcg64;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::path::PathBuf;
8use std::sync::atomic::{AtomicUsize, Ordering};
9use std::sync::{Arc, RwLock};
10
/// Number of unsaved additions at which `add_batch` triggers an automatic save.
const DEFAULT_AUTO_SAVE_THRESHOLD: usize = 100;
/// Embedding length used when `VECTOR_DIMENSIONS` is unset or not a valid number.
const DEFAULT_DIMENSIONS: usize = 384;
13
14/// Euclidean distance metric for HNSW
15#[derive(Default, Clone)]
16pub struct Euclidian;
17
18impl space::Metric<Vec<f32>> for Euclidian {
19    type Unit = u32;
20    fn distance(&self, a: &Vec<f32>, b: &Vec<f32>) -> u32 {
21        let len = a.len().min(b.len());
22        let mut dist_sq = 0.0;
23        for i in 0..len {
24            let diff = a[i] - b[i];
25            dist_sq += diff * diff;
26        }
27        // Use fixed-point arithmetic to avoid bitwise comparison issues
28        (dist_sq.sqrt() * 1_000_000.0) as u32
29    }
30}
31
32/// Persisted vector data
33#[derive(Serialize, Deserialize, Default)]
34struct VectorData {
35    entries: Vec<VectorEntry>,
36}
37
38#[derive(Serialize, Deserialize, Clone)]
39struct VectorEntry {
40    /// Unique identifier for this vector (could be URI or Hash)
41    key: String,
42    embedding: Vec<f32>,
43    /// Optional metadata associated with the vector (serialized as JSON string for compatibility)
44    #[serde(default)]
45    metadata_json: String,
46}
47
48/// Vector store using Local FastEmbed for embeddings
49pub struct VectorStore {
50    /// HNSW index for fast approximate nearest neighbor search
51    index: Arc<RwLock<Hnsw<Euclidian, Vec<f32>, Pcg64, 16, 32>>>,
52    /// Mapping from node ID (internal) to Key
53    id_to_key: Arc<RwLock<HashMap<usize, String>>>,
54    /// Mapping from Key to node ID (internal)
55    key_to_id: Arc<RwLock<HashMap<String, usize>>>,
56    /// Mapping from Key to Metadata (for fast retrieval)
57    key_to_metadata: Arc<RwLock<HashMap<String, serde_json::Value>>>,
58    /// Storage path for persistence
59    storage_path: Option<PathBuf>,
60    /// Local embedding model
61    model: TextEmbedding,
62    /// Vector dimensions
63    dimensions: usize,
64    /// Stored embeddings for persistence
65    embeddings: Arc<RwLock<Vec<VectorEntry>>>,
66    /// Number of unsaved changes
67    dirty_count: Arc<AtomicUsize>,
68    /// Threshold for auto-save
69    auto_save_threshold: usize,
70}
71
72#[derive(Debug, Serialize, Deserialize)]
73pub struct SearchResult {
74    /// The unique key
75    pub key: String,
76    pub score: f32,
77    /// Metadata (including original URI if applicable)
78    pub metadata: serde_json::Value,
79    // Helper to access URI from metadata if it exists, for backward compatibility
80    pub uri: String,
81}
82
83impl VectorStore {
84    /// Create a new vector store for a namespace
85    pub fn new(namespace: &str) -> Result<Self> {
86        // Try to get storage path from environment
87        let storage_path = std::env::var("GRAPH_STORAGE_PATH")
88            .ok()
89            .map(|p| PathBuf::from(p).join(namespace));
90
91        // Get dimensions from env or default
92        let dimensions = std::env::var("VECTOR_DIMENSIONS")
93            .ok()
94            .and_then(|s| s.parse().ok())
95            .unwrap_or(DEFAULT_DIMENSIONS);
96
97        // Initialize FastEmbed model
98        let mut model_opts =
99            InitOptions::new(EmbeddingModel::BGESmallENV15).with_show_download_progress(true);
100
101        if let Ok(cache_path) = std::env::var("FASTEMBED_CACHE_PATH") {
102            model_opts = model_opts.with_cache_dir(PathBuf::from(cache_path));
103        }
104
105        let model = TextEmbedding::try_new(model_opts)?;
106
107        // Create HNSW index
108        let mut index = Hnsw::new(Euclidian);
109        let mut id_to_key = HashMap::new();
110        let mut key_to_id = HashMap::new();
111        let mut key_to_metadata = HashMap::new();
112        let mut embeddings = Vec::new();
113
114        // Try to load persisted vectors
115        if let Some(ref path) = storage_path {
116            let vectors_json = path.join("vectors.json");
117
118            let loaded_data = if vectors_json.exists() {
119                match std::fs::read_to_string(&vectors_json) {
120                    Ok(content) => match serde_json::from_str::<VectorData>(&content) {
121                        Ok(data) => Some(data),
122                        Err(e) => {
123                            eprintln!("ERROR: Failed to parse vectors: {}", e);
124                            None
125                        }
126                    },
127                    Err(_) => None,
128                }
129            } else {
130                None
131            };
132
133            if let Some(data) = loaded_data {
134                let mut searcher = hnsw::Searcher::default();
135                for entry in data.entries {
136                    if entry.embedding.len() == dimensions {
137                        let id = index.insert(entry.embedding.clone(), &mut searcher);
138                        id_to_key.insert(id, entry.key.clone());
139                        key_to_id.insert(entry.key.clone(), id);
140                        
141                        let metadata = serde_json::from_str(&entry.metadata_json).unwrap_or(serde_json::Value::Null);
142                        key_to_metadata.insert(entry.key.clone(), metadata);
143                        embeddings.push(entry);
144                    }
145                }
146                eprintln!("Loaded {} vectors from disk", embeddings.len());
147            }
148        }
149
150        Ok(Self {
151            index: Arc::new(RwLock::new(index)),
152            id_to_key: Arc::new(RwLock::new(id_to_key)),
153            key_to_id: Arc::new(RwLock::new(key_to_id)),
154            key_to_metadata: Arc::new(RwLock::new(key_to_metadata)),
155            storage_path,
156            model,
157            dimensions,
158            embeddings: Arc::new(RwLock::new(embeddings)),
159            dirty_count: Arc::new(AtomicUsize::new(0)),
160            auto_save_threshold: DEFAULT_AUTO_SAVE_THRESHOLD,
161        })
162    }
163
164    /// Save vectors to disk (JSON format for robust cross-version compatibility)
165    fn save_vectors(&self) -> Result<()> {
166        if let Some(ref path) = self.storage_path {
167            std::fs::create_dir_all(path)?;
168
169            let (entries, current_dirty) = {
170                let guard = self.embeddings.read().unwrap();
171                (guard.clone(), self.dirty_count.load(Ordering::Relaxed))
172            };
173
174            let data = VectorData { entries };
175            let json = serde_json::to_string_pretty(&data)?;
176            std::fs::write(path.join("vectors.json"), json)?;
177
178            if current_dirty > 0 {
179                let _ = self.dirty_count.fetch_sub(current_dirty, Ordering::Relaxed);
180            }
181        }
182        Ok(())
183    }
184
185    pub fn flush(&self) -> Result<()> {
186        self.save_vectors()
187    }
188
189    pub async fn embed(&self, text: &str) -> Result<Vec<f32>> {
190        let embeddings = self.embed_batch(vec![text.to_string()]).await?;
191        Ok(embeddings[0].clone())
192    }
193
194    pub async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
195        if texts.is_empty() {
196            return Ok(Vec::new());
197        }
198        let embeddings = self.model.embed(texts, None)?;
199        Ok(embeddings)
200    }
201
202    pub async fn add(
203        &self,
204        key: &str,
205        content: &str,
206        metadata: serde_json::Value,
207    ) -> Result<usize> {
208        let results = self
209            .add_batch(vec![(key.to_string(), content.to_string(), metadata)])
210            .await?;
211        Ok(results[0])
212    }
213
214    pub async fn add_batch(
215        &self,
216        items: Vec<(String, String, serde_json::Value)>,
217    ) -> Result<Vec<usize>> {
218        let mut new_items = Vec::new();
219        let mut result_ids = vec![0; items.len()];
220        let mut new_indices = Vec::new();
221
222        {
223            let key_map = self.key_to_id.read().unwrap();
224            for (i, (key, content, _)) in items.iter().enumerate() {
225                if let Some(&id) = key_map.get(key) {
226                    result_ids[i] = id;
227                } else {
228                    new_items.push(content.clone());
229                    new_indices.push(i);
230                }
231            }
232        }
233
234        if new_items.is_empty() {
235            return Ok(result_ids);
236        }
237
238        let embeddings = self.embed_batch(new_items).await?;
239        let mut ids_to_add = Vec::new();
240        let mut searcher = hnsw::Searcher::default();
241
242        {
243            let mut index = self.index.write().unwrap();
244            let mut key_map = self.key_to_id.write().unwrap();
245            let mut id_map = self.id_to_key.write().unwrap();
246            let mut metadata_map = self.key_to_metadata.write().unwrap();
247            let mut embs = self.embeddings.write().unwrap();
248
249            for (i, embedding) in embeddings.into_iter().enumerate() {
250                let original_idx = new_indices[i];
251                let (key, _, metadata) = &items[original_idx];
252
253                if let Some(&id) = key_map.get(key) {
254                    result_ids[original_idx] = id;
255                    continue;
256                }
257
258                let id = index.insert(embedding.clone(), &mut searcher);
259                key_map.insert(key.clone(), id);
260                id_map.insert(id, key.clone());
261                metadata_map.insert(key.clone(), metadata.clone());
262
263                embs.push(VectorEntry {
264                    key: key.clone(),
265                    embedding,
266                    metadata_json: serde_json::to_string(metadata).unwrap_or_default(),
267                });
268
269                result_ids[original_idx] = id;
270                ids_to_add.push(id);
271            }
272        }
273
274        if !ids_to_add.is_empty() {
275            let count = self
276                .dirty_count
277                .fetch_add(ids_to_add.len(), Ordering::Relaxed);
278            if count + ids_to_add.len() >= self.auto_save_threshold {
279                let _ = self.save_vectors();
280            }
281        }
282
283        Ok(result_ids)
284    }
285
286    pub async fn search(&self, query: &str, k: usize) -> Result<Vec<SearchResult>> {
287        let query_embedding = self.embed(query).await?;
288        let mut searcher = hnsw::Searcher::default();
289
290        let index = self.index.read().unwrap();
291        let len = index.len();
292        if len == 0 {
293            return Ok(Vec::new());
294        }
295
296        let k = k.min(len);
297        let ef = k.max(50);
298
299        let mut neighbors = vec![
300            space::Neighbor {
301                index: 0,
302                distance: u32::MAX
303            };
304            k
305        ];
306
307        let found_neighbors = index.nearest(&query_embedding, ef, &mut searcher, &mut neighbors);
308
309        let id_map = self.id_to_key.read().unwrap();
310        let metadata_map = self.key_to_metadata.read().unwrap();
311
312        let results: Vec<SearchResult> = found_neighbors
313            .iter()
314            .filter_map(|neighbor| {
315                id_map.get(&neighbor.index).map(|key| {
316                    let score_f32 = (neighbor.distance as f32) / 1_000_000.0;
317                    let metadata = metadata_map
318                        .get(key)
319                        .cloned()
320                        .unwrap_or(serde_json::Value::Null);
321                    let uri = metadata
322                        .get("uri")
323                        .and_then(|v| v.as_str())
324                        .unwrap_or(key)
325                        .to_string();
326
327                    SearchResult {
328                        key: key.clone(),
329                        score: 1.0 / (1.0 + score_f32),
330                        metadata,
331                        uri,
332                    }
333                })
334            })
335            .collect();
336
337        Ok(results)
338    }
339
340    pub fn get_key(&self, id: usize) -> Option<String> {
341        self.id_to_key.read().unwrap().get(&id).cloned()
342    }
343
344    pub fn get_id(&self, key: &str) -> Option<usize> {
345        self.key_to_id.read().unwrap().get(key).copied()
346    }
347
348    pub fn len(&self) -> usize {
349        self.key_to_id.read().unwrap().len()
350    }
351
352    pub fn is_empty(&self) -> bool {
353        self.len() == 0
354    }
355
356    pub fn compact(&self) -> Result<usize> {
357        let embeddings = self.embeddings.read().unwrap();
358        let current_keys: std::collections::HashSet<_> =
359            self.key_to_id.read().unwrap().keys().cloned().collect();
360
361        if current_keys.is_empty() && !embeddings.is_empty() {
362            return Ok(0);
363        }
364
365        let active_entries: Vec<_> = embeddings
366            .iter()
367            .filter(|e| current_keys.contains(&e.key))
368            .cloned()
369            .collect();
370
371        let removed = embeddings.len() - active_entries.len();
372        if removed == 0 {
373            return Ok(0);
374        }
375
376        let mut new_index = hnsw::Hnsw::new(Euclidian);
377        let mut new_id_to_key = std::collections::HashMap::new();
378        let mut new_key_to_id = std::collections::HashMap::new();
379        let mut new_key_to_metadata = std::collections::HashMap::new();
380        let mut searcher = hnsw::Searcher::default();
381
382        for entry in &active_entries {
383            if entry.embedding.len() == self.dimensions {
384                let id = new_index.insert(entry.embedding.clone(), &mut searcher);
385                new_id_to_key.insert(id, entry.key.clone());
386                new_key_to_id.insert(entry.key.clone(), id);
387                let metadata = serde_json::from_str(&entry.metadata_json).unwrap_or(serde_json::Value::Null);
388                new_key_to_metadata.insert(entry.key.clone(), metadata);
389            }
390        }
391
392        *self.index.write().unwrap() = new_index;
393        *self.id_to_key.write().unwrap() = new_id_to_key;
394        *self.key_to_id.write().unwrap() = new_key_to_id;
395        *self.key_to_metadata.write().unwrap() = new_key_to_metadata;
396
397        drop(embeddings);
398        *self.embeddings.write().unwrap() = active_entries;
399        let _ = self.save_vectors();
400        Ok(removed)
401    }
402
403    pub fn remove(&self, key: &str) -> bool {
404        let mut key_map = self.key_to_id.write().unwrap();
405        let mut id_map = self.id_to_key.write().unwrap();
406        let mut metadata_map = self.key_to_metadata.write().unwrap();
407
408        if let Some(id) = key_map.remove(key) {
409            id_map.remove(&id);
410            metadata_map.remove(key);
411            true
412        } else {
413            false
414        }
415    }
416
417    pub fn stats(&self) -> (usize, usize, usize) {
418        let embeddings_count = self.embeddings.read().unwrap().len();
419        let active_count = self.key_to_id.read().unwrap().len();
420        let stale_count = embeddings_count.saturating_sub(active_count);
421        (active_count, stale_count, embeddings_count)
422    }
423}