Skip to main content

synapse_core/
vector_store.rs

1use crate::persistence::{load_bincode, save_bincode};
2use anyhow::Result;
3use hnsw::Hnsw;
4use rand_pcg::Pcg64;
5use reqwest::Client;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::path::PathBuf;
9use std::sync::atomic::{AtomicUsize, Ordering};
10use std::sync::{Arc, RwLock};
11
/// Base URL of the HuggingFace inference router; the model id and pipeline path are
/// appended to it for each request.
const HUGGINGFACE_API_URL: &str = "https://router.huggingface.co/hf-inference/models";
/// Sentence-embedding model queried by default (384-dimensional output, fast).
const DEFAULT_MODEL: &str = "sentence-transformers/all-MiniLM-L6-v2";
/// Embedding width used when `VECTOR_DIMENSIONS` is not set; matches [`DEFAULT_MODEL`].
const DEFAULT_DIMENSIONS: usize = 384;
/// Number of unsaved changes after which the store writes itself to disk.
const DEFAULT_AUTO_SAVE_THRESHOLD: usize = 100;
16
17/// Euclidean distance metric for HNSW
18#[derive(Default, Clone)]
19pub struct Euclidian;
20
21impl space::Metric<Vec<f32>> for Euclidian {
22    type Unit = u32;
23    fn distance(&self, a: &Vec<f32>, b: &Vec<f32>) -> u32 {
24        let len = a.len().min(b.len());
25        let mut dist_sq = 0.0;
26        for i in 0..len {
27            let diff = a[i] - b[i];
28            dist_sq += diff * diff;
29        }
30        // Use fixed-point arithmetic to avoid bitwise comparison issues
31        (dist_sq.sqrt() * 1_000_000.0) as u32
32    }
33}
34
35/// Persisted vector data
36#[derive(Serialize, Deserialize, Default)]
37struct VectorData {
38    entries: Vec<VectorEntry>,
39}
40
41#[derive(Serialize, Deserialize, Clone)]
42struct VectorEntry {
43    /// Unique identifier for this vector (could be URI or Hash)
44    key: String,
45    embedding: Vec<f32>,
46    /// Optional metadata associated with the vector
47    #[serde(default)]
48    metadata: serde_json::Value,
49}
50
51/// Vector store using HuggingFace Inference API for embeddings
52pub struct VectorStore {
53    /// HNSW index for fast approximate nearest neighbor search
54    index: Arc<RwLock<Hnsw<Euclidian, Vec<f32>, Pcg64, 16, 32>>>,
55    /// Mapping from node ID (internal) to Key
56    id_to_key: Arc<RwLock<HashMap<usize, String>>>,
57    /// Mapping from Key to node ID (internal)
58    key_to_id: Arc<RwLock<HashMap<String, usize>>>,
59    /// Mapping from Key to Metadata (for fast retrieval)
60    key_to_metadata: Arc<RwLock<HashMap<String, serde_json::Value>>>,
61    /// Storage path for persistence
62    storage_path: Option<PathBuf>,
63    /// HTTP client for HuggingFace API
64    client: Client,
65    /// HuggingFace API URL (configurable)
66    api_url: String,
67    /// HuggingFace API token (optional, for rate limits)
68    api_token: Option<String>,
69    /// Model name
70    model: String,
71    /// Vector dimensions
72    dimensions: usize,
73    /// Stored embeddings for persistence
74    embeddings: Arc<RwLock<Vec<VectorEntry>>>,
75    /// Number of unsaved changes
76    dirty_count: Arc<AtomicUsize>,
77    /// Threshold for auto-save
78    auto_save_threshold: usize,
79}
80
81#[derive(Debug, Serialize, Deserialize)]
82pub struct SearchResult {
83    /// The unique key
84    pub key: String,
85    pub score: f32,
86    /// Metadata (including original URI if applicable)
87    pub metadata: serde_json::Value,
88    // Helper to access URI from metadata if it exists, for backward compatibility
89    pub uri: String,
90}
91
92impl VectorStore {
93    /// Create a new vector store for a namespace
94    pub fn new(namespace: &str) -> Result<Self> {
95        // Try to get storage path from environment
96        let storage_path = std::env::var("GRAPH_STORAGE_PATH")
97            .ok()
98            .map(|p| PathBuf::from(p).join(namespace));
99
100        // Get dimensions from env or default
101        let dimensions = std::env::var("VECTOR_DIMENSIONS")
102            .ok()
103            .and_then(|s| s.parse().ok())
104            .unwrap_or(DEFAULT_DIMENSIONS);
105
106        // Create HNSW index with Euclidian metric
107        let mut index = Hnsw::new(Euclidian);
108        let mut id_to_key = HashMap::new();
109        let mut key_to_id = HashMap::new();
110        let mut key_to_metadata = HashMap::new();
111        let mut embeddings = Vec::new();
112
113        // Try to load persisted vectors
114        if let Some(ref path) = storage_path {
115            let vectors_bin = path.join("vectors.bin");
116            let vectors_json = path.join("vectors.json");
117
118            let loaded_data = if vectors_bin.exists() {
119                load_bincode::<VectorData>(&vectors_bin).ok()
120            } else if vectors_json.exists() {
121                // Fallback / Migration from JSON
122                let content = std::fs::read_to_string(&vectors_json).ok();
123                if let Some(content) = content {
124                    // Try new format first
125                    if let Ok(data) = serde_json::from_str::<VectorData>(&content) {
126                        Some(data)
127                    } else {
128                        // Fallback: Try loading old format
129                        #[derive(Serialize, Deserialize)]
130                        struct OldVectorData {
131                            entries: Vec<OldVectorEntry>,
132                        }
133                        #[derive(Serialize, Deserialize)]
134                        struct OldVectorEntry {
135                            uri: String,
136                            embedding: Vec<f32>,
137                        }
138
139                        if let Ok(old_data) = serde_json::from_str::<OldVectorData>(&content) {
140                            let entries = old_data
141                                .entries
142                                .into_iter()
143                                .map(|old| VectorEntry {
144                                    key: old.uri.clone(),
145                                    embedding: old.embedding,
146                                    metadata: serde_json::json!({ "uri": old.uri }),
147                                })
148                                .collect();
149                            Some(VectorData { entries })
150                        } else {
151                            None
152                        }
153                    }
154                } else {
155                    None
156                }
157            } else {
158                None
159            };
160
161            if let Some(data) = loaded_data {
162                let mut searcher = hnsw::Searcher::default();
163                for entry in data.entries {
164                    if entry.embedding.len() == dimensions {
165                        let id = index.insert(entry.embedding.clone(), &mut searcher);
166                        id_to_key.insert(id, entry.key.clone());
167                        key_to_id.insert(entry.key.clone(), id);
168                        key_to_metadata.insert(entry.key.clone(), entry.metadata.clone());
169                        embeddings.push(entry);
170                    }
171                }
172                eprintln!(
173                    "Loaded {} vectors from disk (dim={})",
174                    embeddings.len(),
175                    dimensions
176                );
177            }
178        }
179
180        // Get API token from environment (optional)
181        let api_token = std::env::var("HUGGINGFACE_API_TOKEN").ok();
182
183        // Get API URL from environment (or default)
184        let api_url = std::env::var("HUGGINGFACE_API_URL")
185            .unwrap_or_else(|_| HUGGINGFACE_API_URL.to_string());
186
187        // Configured client with timeout
188        let client = Client::builder()
189            .timeout(std::time::Duration::from_secs(30))
190            .build()
191            .unwrap_or_else(|_| Client::new());
192
193        Ok(Self {
194            index: Arc::new(RwLock::new(index)),
195            id_to_key: Arc::new(RwLock::new(id_to_key)),
196            key_to_id: Arc::new(RwLock::new(key_to_id)),
197            key_to_metadata: Arc::new(RwLock::new(key_to_metadata)),
198            storage_path,
199            client,
200            api_url,
201            api_token,
202            model: DEFAULT_MODEL.to_string(),
203            dimensions,
204            embeddings: Arc::new(RwLock::new(embeddings)),
205            dirty_count: Arc::new(AtomicUsize::new(0)),
206            auto_save_threshold: DEFAULT_AUTO_SAVE_THRESHOLD,
207        })
208    }
209
210    /// Save vectors to disk
211    fn save_vectors(&self) -> Result<()> {
212        if let Some(ref path) = self.storage_path {
213            std::fs::create_dir_all(path)?;
214
215            // Hold read lock during serialization AND dirty count reset to avoid race condition.
216            // Wait, we need to read the dirty count at the time of snapshot.
217            // But we can't atomically read-and-subtract without a loop if we don't hold a write lock on dirty_count (which is Atomic).
218            // Actually, if we just subtract the value we *saw* before starting the save, new items added during save will just remain in the counter.
219            // But we are saving the *entire* vector list, so any items added during save (if that were possible with the lock held) would be included?
220            // `embeddings` is protected by RwLock.
221            // We take a read lock on `embeddings` to clone/serialize.
222            // While we hold the read lock, NO new items can be added (add_batch needs write lock).
223            // So the snapshot is consistent.
224            // The race happens *after* we drop the read lock and *before* we reset dirty_count.
225            // During that window, a writer could add items and increment dirty_count.
226            // If we then reset to 0, we lose those counts.
227            //
228            // Fix: Read dirty_count *while holding the read lock*.
229            // Since writers are blocked, dirty_count cannot change while we hold the lock.
230            // So `current_dirty` will be exactly the number of unsaved items included in `entries`.
231            // Then we perform the save (IO).
232            // Finally we subtract `current_dirty`.
233            // If new items come in during IO (after read lock drop), they increment counter.
234            // Subtracting `current_dirty` leaves those new increments intact. Correct.
235
236            let (entries, current_dirty) = {
237                let guard = self.embeddings.read().unwrap();
238                (guard.clone(), self.dirty_count.load(Ordering::Relaxed))
239            };
240
241            let data = VectorData { entries };
242            save_bincode(&path.join("vectors.bin"), &data)?;
243
244            if current_dirty > 0 {
245                let _ = self.dirty_count.fetch_sub(current_dirty, Ordering::Relaxed);
246            }
247        }
248        Ok(())
249    }
250
251    /// Force save to disk
252    pub fn flush(&self) -> Result<()> {
253        self.save_vectors()
254    }
255
256    /// Generate embedding for a text using HuggingFace Inference API
257    /// (Mocked if MOCK_EMBEDDINGS is set)
258    pub async fn embed(&self, text: &str) -> Result<Vec<f32>> {
259        if std::env::var("MOCK_EMBEDDINGS").is_ok() {
260            // Return random embedding for testing
261            use rand::Rng;
262            let mut rng = rand::rng();
263            let vec: Vec<f32> = (0..self.dimensions).map(|_| rng.random()).collect();
264            return Ok(vec);
265        }
266
267        let embeddings = self.embed_batch(vec![text.to_string()]).await?;
268        Ok(embeddings[0].clone())
269    }
270
271    /// Generate embeddings for multiple texts using HuggingFace Inference API
272    pub async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
273        if texts.is_empty() {
274            return Ok(Vec::new());
275        }
276
277        if std::env::var("MOCK_EMBEDDINGS").is_ok() {
278            use rand::Rng;
279            let mut rng = rand::rng();
280            let mut results = Vec::new();
281            for _ in 0..texts.len() {
282                let vec: Vec<f32> = (0..self.dimensions).map(|_| rng.random()).collect();
283                results.push(vec);
284            }
285            return Ok(results);
286        }
287
288        let url = format!(
289            "{}/{}/pipeline/feature-extraction",
290            self.api_url, self.model
291        );
292
293        // HuggingFace accepts array of strings for inputs
294        let mut request = self.client.post(&url).json(&serde_json::json!({
295            "inputs": texts,
296        }));
297
298        // Add auth token if available
299        if let Some(ref token) = self.api_token {
300            request = request.header("Authorization", format!("Bearer {}", token));
301        }
302
303        let response = request.send().await?;
304
305        if !response.status().is_success() {
306            let error_text = response.text().await?;
307            anyhow::bail!("HuggingFace API error: {}", error_text);
308        }
309
310        let response_json: serde_json::Value = response.json().await?;
311        let mut results = Vec::new();
312
313        if let Some(arr) = response_json.as_array() {
314            for item in arr {
315                let vec: Vec<f32> = serde_json::from_value(item.clone())
316                    .map_err(|e| anyhow::anyhow!("Failed to parse embedding: {}", e))?;
317
318                if vec.len() != self.dimensions {
319                    anyhow::bail!("Expected {} dimensions, got {}", self.dimensions, vec.len());
320                }
321                results.push(vec);
322            }
323        } else {
324            // Handle case where we sent 1 text and got single flat array [0.1, ...]
325            if texts.len() == 1 {
326                if let Ok(vec) = serde_json::from_value::<Vec<f32>>(response_json) {
327                    if vec.len() == self.dimensions {
328                        results.push(vec);
329                    }
330                }
331            }
332        }
333
334        if results.len() != texts.len() {
335            anyhow::bail!("Expected {} embeddings, got {}", texts.len(), results.len());
336        }
337
338        Ok(results)
339    }
340
341    /// Add a key with its text content to the index
342    pub async fn add(
343        &self,
344        key: &str,
345        content: &str,
346        metadata: serde_json::Value,
347    ) -> Result<usize> {
348        let results = self
349            .add_batch(vec![(key.to_string(), content.to_string(), metadata)])
350            .await?;
351        Ok(results[0])
352    }
353
354    /// Add multiple keys with their text content to the index
355    pub async fn add_batch(
356        &self,
357        items: Vec<(String, String, serde_json::Value)>,
358    ) -> Result<Vec<usize>> {
359        // Filter out existing keys
360        let mut new_items = Vec::new();
361        let mut result_ids = vec![0; items.len()];
362        let mut new_indices = Vec::new(); // Map index in new_items to index in items
363
364        {
365            let key_map = self.key_to_id.read().unwrap();
366            for (i, (key, content, _)) in items.iter().enumerate() {
367                if let Some(&id) = key_map.get(key) {
368                    result_ids[i] = id;
369                } else {
370                    new_items.push(content.clone());
371                    new_indices.push(i);
372                }
373            }
374        }
375
376        if new_items.is_empty() {
377            return Ok(result_ids);
378        }
379
380        // Generate embeddings via HuggingFace API
381        let embeddings = self.embed_batch(new_items).await?;
382
383        let mut ids_to_add = Vec::new();
384        let mut searcher = hnsw::Searcher::default();
385
386        // Add to HNSW index and maps
387        {
388            let mut index = self.index.write().unwrap();
389            let mut key_map = self.key_to_id.write().unwrap();
390            let mut id_map = self.id_to_key.write().unwrap();
391            let mut metadata_map = self.key_to_metadata.write().unwrap();
392            let mut embs = self.embeddings.write().unwrap();
393
394            for (i, embedding) in embeddings.into_iter().enumerate() {
395                let original_idx = new_indices[i];
396                let (key, _, metadata) = &items[original_idx];
397
398                // Double check if inserted in race condition
399                if let Some(&id) = key_map.get(key) {
400                    result_ids[original_idx] = id;
401                    continue;
402                }
403
404                let id = index.insert(embedding.clone(), &mut searcher);
405                key_map.insert(key.clone(), id);
406                id_map.insert(id, key.clone());
407                metadata_map.insert(key.clone(), metadata.clone());
408
409                embs.push(VectorEntry {
410                    key: key.clone(),
411                    embedding,
412                    metadata: metadata.clone(),
413                });
414
415                result_ids[original_idx] = id;
416                ids_to_add.push(id);
417            }
418        }
419
420        if !ids_to_add.is_empty() {
421            let count = self
422                .dirty_count
423                .fetch_add(ids_to_add.len(), Ordering::Relaxed);
424            if count + ids_to_add.len() >= self.auto_save_threshold {
425                let _ = self.save_vectors();
426            }
427        }
428
429        Ok(result_ids)
430    }
431
432    /// Search for similar vectors
433    pub async fn search(&self, query: &str, k: usize) -> Result<Vec<SearchResult>> {
434        // Generate query embedding via HuggingFace API
435        let query_embedding = self.embed(query).await?;
436
437        // Search HNSW index
438        let mut searcher = hnsw::Searcher::default();
439
440        let index = self.index.read().unwrap();
441        let len = index.len();
442        if len == 0 {
443            return Ok(Vec::new());
444        }
445
446        let k = k.min(len);
447        let ef = k.max(50); // Ensure ef is at least k, but usually 50+
448
449        let mut neighbors = vec![
450            space::Neighbor {
451                index: 0,
452                distance: u32::MAX
453            };
454            k
455        ];
456
457        let found_neighbors = index.nearest(&query_embedding, ef, &mut searcher, &mut neighbors);
458
459        // Convert to results
460        let id_map = self.id_to_key.read().unwrap();
461        let metadata_map = self.key_to_metadata.read().unwrap();
462
463        let results: Vec<SearchResult> = found_neighbors
464            .iter()
465            .filter_map(|neighbor| {
466                id_map.get(&neighbor.index).map(|key| {
467                    let score_f32 = (neighbor.distance as f32) / 1_000_000.0;
468
469                    let metadata = metadata_map
470                        .get(key)
471                        .cloned()
472                        .unwrap_or(serde_json::Value::Null);
473
474                    let uri = metadata
475                        .get("uri")
476                        .and_then(|v| v.as_str())
477                        .unwrap_or(key)
478                        .to_string();
479
480                    SearchResult {
481                        key: key.clone(),
482                        score: 1.0 / (1.0 + score_f32),
483                        metadata,
484                        uri,
485                    }
486                })
487            })
488            .collect();
489
490        Ok(results)
491    }
492
493    pub fn get_key(&self, id: usize) -> Option<String> {
494        self.id_to_key.read().unwrap().get(&id).cloned()
495    }
496
497    pub fn get_id(&self, key: &str) -> Option<usize> {
498        self.key_to_id.read().unwrap().get(key).copied()
499    }
500
501    pub fn len(&self) -> usize {
502        self.key_to_id.read().unwrap().len()
503    }
504
505    pub fn is_empty(&self) -> bool {
506        self.len() == 0
507    }
508
509    /// Compaction: rebuild index from stored embeddings, removing stale entries
510    pub fn compact(&self) -> Result<usize> {
511        let embeddings = self.embeddings.read().unwrap();
512        let current_keys: std::collections::HashSet<_> =
513            self.key_to_id.read().unwrap().keys().cloned().collect();
514
515        if current_keys.is_empty() && !embeddings.is_empty() {
516            return Ok(0);
517        }
518
519        // Filter to only current Keys
520        let active_entries: Vec<_> = embeddings
521            .iter()
522            .filter(|e| current_keys.contains(&e.key))
523            .cloned()
524            .collect();
525
526        let removed = embeddings.len() - active_entries.len();
527
528        if removed == 0 {
529            return Ok(0);
530        }
531
532        // Rebuild index
533        let mut new_index = hnsw::Hnsw::new(Euclidian);
534        let mut new_id_to_key = std::collections::HashMap::new();
535        let mut new_key_to_id = std::collections::HashMap::new();
536        let mut new_key_to_metadata = std::collections::HashMap::new();
537        let mut searcher = hnsw::Searcher::default();
538
539        for entry in &active_entries {
540            if entry.embedding.len() == self.dimensions {
541                let id = new_index.insert(entry.embedding.clone(), &mut searcher);
542                new_id_to_key.insert(id, entry.key.clone());
543                new_key_to_id.insert(entry.key.clone(), id);
544                new_key_to_metadata.insert(entry.key.clone(), entry.metadata.clone());
545            }
546        }
547
548        // Swap in new index
549        *self.index.write().unwrap() = new_index;
550        *self.id_to_key.write().unwrap() = new_id_to_key;
551        *self.key_to_id.write().unwrap() = new_key_to_id;
552        *self.key_to_metadata.write().unwrap() = new_key_to_metadata;
553
554        // Update embeddings (drop takes write lock)
555        drop(embeddings);
556        *self.embeddings.write().unwrap() = active_entries;
557
558        let _ = self.save_vectors();
559
560        Ok(removed)
561    }
562
563    /// Remove a Key from the vector store
564    pub fn remove(&self, key: &str) -> bool {
565        let mut key_map = self.key_to_id.write().unwrap();
566        let mut id_map = self.id_to_key.write().unwrap();
567        let mut metadata_map = self.key_to_metadata.write().unwrap();
568
569        if let Some(id) = key_map.remove(key) {
570            id_map.remove(&id);
571            metadata_map.remove(key);
572            // Note: actual index entry remains until compaction
573            true
574        } else {
575            false
576        }
577    }
578
579    /// Get storage stats
580    pub fn stats(&self) -> (usize, usize, usize) {
581        let embeddings_count = self.embeddings.read().unwrap().len();
582        let active_count = self.key_to_id.read().unwrap().len();
583        let stale_count = embeddings_count.saturating_sub(active_count);
584        (active_count, stale_count, embeddings_count)
585    }
586}