Skip to main content

synapse_core/
vector_store.rs

1use crate::persistence::{load_bincode, save_bincode};
2use anyhow::Result;
3use hnsw::Hnsw;
4use rand_pcg::Pcg64;
5use reqwest::Client;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::path::PathBuf;
9use std::sync::atomic::{AtomicUsize, Ordering};
10use std::sync::{Arc, RwLock};
11
/// Base URL of the HuggingFace inference router; the model id and pipeline
/// path are appended per request.
const HUGGINGFACE_API_URL: &str = "https://router.huggingface.co/hf-inference/models";

/// Embedding model used by default (384-dimensional output, fast).
const DEFAULT_MODEL: &str = "sentence-transformers/all-MiniLM-L6-v2";

/// Expected embedding length when `VECTOR_DIMENSIONS` is not set; matches
/// the output size of [`DEFAULT_MODEL`].
const DEFAULT_DIMENSIONS: usize = 384;

/// Number of unsaved changes after which the store writes itself to disk.
const DEFAULT_AUTO_SAVE_THRESHOLD: usize = 100;
16
17/// Euclidean distance metric for HNSW
18#[derive(Default, Clone)]
19pub struct Euclidian;
20
21impl space::Metric<Vec<f32>> for Euclidian {
22    type Unit = u32;
23    fn distance(&self, a: &Vec<f32>, b: &Vec<f32>) -> u32 {
24        let len = a.len().min(b.len());
25        let mut dist_sq = 0.0;
26        for i in 0..len {
27            let diff = a[i] - b[i];
28            dist_sq += diff * diff;
29        }
30        // Use fixed-point arithmetic to avoid bitwise comparison issues
31        (dist_sq.sqrt() * 1_000_000.0) as u32
32    }
33}
34
35/// Persisted vector data
36#[derive(Serialize, Deserialize, Default)]
37struct VectorData {
38    entries: Vec<VectorEntry>,
39}
40
41#[derive(Serialize, Deserialize, Clone)]
42struct VectorEntry {
43    /// Unique identifier for this vector (could be URI or Hash)
44    key: String,
45    embedding: Vec<f32>,
46    /// Optional metadata associated with the vector
47    #[serde(default)]
48    metadata: serde_json::Value,
49}
50
51/// Vector store using HuggingFace Inference API for embeddings
52pub struct VectorStore {
53    /// HNSW index for fast approximate nearest neighbor search
54    index: Arc<RwLock<Hnsw<Euclidian, Vec<f32>, Pcg64, 16, 32>>>,
55    /// Mapping from node ID (internal) to Key
56    id_to_key: Arc<RwLock<HashMap<usize, String>>>,
57    /// Mapping from Key to node ID (internal)
58    key_to_id: Arc<RwLock<HashMap<String, usize>>>,
59    /// Mapping from Key to Metadata (for fast retrieval)
60    key_to_metadata: Arc<RwLock<HashMap<String, serde_json::Value>>>,
61    /// Storage path for persistence
62    storage_path: Option<PathBuf>,
63    /// HTTP client for HuggingFace API
64    client: Client,
65    /// HuggingFace API token (optional, for rate limits)
66    api_token: Option<String>,
67    /// Model name
68    model: String,
69    /// Vector dimensions
70    dimensions: usize,
71    /// Stored embeddings for persistence
72    embeddings: Arc<RwLock<Vec<VectorEntry>>>,
73    /// Number of unsaved changes
74    dirty_count: Arc<AtomicUsize>,
75    /// Threshold for auto-save
76    auto_save_threshold: usize,
77}
78
79#[derive(Debug, Serialize, Deserialize)]
80pub struct SearchResult {
81    /// The unique key
82    pub key: String,
83    pub score: f32,
84    /// Metadata (including original URI if applicable)
85    pub metadata: serde_json::Value,
86    // Helper to access URI from metadata if it exists, for backward compatibility
87    pub uri: String,
88}
89
90impl VectorStore {
91    /// Create a new vector store for a namespace
92    pub fn new(namespace: &str) -> Result<Self> {
93        // Try to get storage path from environment
94        let storage_path = std::env::var("GRAPH_STORAGE_PATH")
95            .ok()
96            .map(|p| PathBuf::from(p).join(namespace));
97
98        // Get dimensions from env or default
99        let dimensions = std::env::var("VECTOR_DIMENSIONS")
100            .ok()
101            .and_then(|s| s.parse().ok())
102            .unwrap_or(DEFAULT_DIMENSIONS);
103
104        // Create HNSW index with Euclidian metric
105        let mut index = Hnsw::new(Euclidian);
106        let mut id_to_key = HashMap::new();
107        let mut key_to_id = HashMap::new();
108        let mut key_to_metadata = HashMap::new();
109        let mut embeddings = Vec::new();
110
111        // Try to load persisted vectors
112        if let Some(ref path) = storage_path {
113            let vectors_bin = path.join("vectors.bin");
114            let vectors_json = path.join("vectors.json");
115
116            let loaded_data = if vectors_bin.exists() {
117                load_bincode::<VectorData>(&vectors_bin).ok()
118            } else if vectors_json.exists() {
119                // Fallback / Migration from JSON
120                let content = std::fs::read_to_string(&vectors_json).ok();
121                if let Some(content) = content {
122                    // Try new format first
123                    if let Ok(data) = serde_json::from_str::<VectorData>(&content) {
124                        Some(data)
125                    } else {
126                        // Fallback: Try loading old format
127                        #[derive(Serialize, Deserialize)]
128                        struct OldVectorData {
129                            entries: Vec<OldVectorEntry>,
130                        }
131                        #[derive(Serialize, Deserialize)]
132                        struct OldVectorEntry {
133                            uri: String,
134                            embedding: Vec<f32>,
135                        }
136
137                        if let Ok(old_data) = serde_json::from_str::<OldVectorData>(&content) {
138                            let entries = old_data
139                                .entries
140                                .into_iter()
141                                .map(|old| VectorEntry {
142                                    key: old.uri.clone(),
143                                    embedding: old.embedding,
144                                    metadata: serde_json::json!({ "uri": old.uri }),
145                                })
146                                .collect();
147                            Some(VectorData { entries })
148                        } else {
149                            None
150                        }
151                    }
152                } else {
153                    None
154                }
155            } else {
156                None
157            };
158
159            if let Some(data) = loaded_data {
160                let mut searcher = hnsw::Searcher::default();
161                for entry in data.entries {
162                    if entry.embedding.len() == dimensions {
163                        let id = index.insert(entry.embedding.clone(), &mut searcher);
164                        id_to_key.insert(id, entry.key.clone());
165                        key_to_id.insert(entry.key.clone(), id);
166                        key_to_metadata.insert(entry.key.clone(), entry.metadata.clone());
167                        embeddings.push(entry);
168                    }
169                }
170                eprintln!(
171                    "Loaded {} vectors from disk (dim={})",
172                    embeddings.len(),
173                    dimensions
174                );
175            }
176        }
177
178        // Get API token from environment (optional)
179        let api_token = std::env::var("HUGGINGFACE_API_TOKEN").ok();
180
181        // Configured client with timeout
182        let client = Client::builder()
183            .timeout(std::time::Duration::from_secs(30))
184            .build()
185            .unwrap_or_else(|_| Client::new());
186
187        Ok(Self {
188            index: Arc::new(RwLock::new(index)),
189            id_to_key: Arc::new(RwLock::new(id_to_key)),
190            key_to_id: Arc::new(RwLock::new(key_to_id)),
191            key_to_metadata: Arc::new(RwLock::new(key_to_metadata)),
192            storage_path,
193            client,
194            api_token,
195            model: DEFAULT_MODEL.to_string(),
196            dimensions,
197            embeddings: Arc::new(RwLock::new(embeddings)),
198            dirty_count: Arc::new(AtomicUsize::new(0)),
199            auto_save_threshold: DEFAULT_AUTO_SAVE_THRESHOLD,
200        })
201    }
202
203    /// Save vectors to disk
204    fn save_vectors(&self) -> Result<()> {
205        if let Some(ref path) = self.storage_path {
206            std::fs::create_dir_all(path)?;
207
208            // Hold read lock during serialization AND dirty count reset to avoid race condition.
209            // Wait, we need to read the dirty count at the time of snapshot.
210            // But we can't atomically read-and-subtract without a loop if we don't hold a write lock on dirty_count (which is Atomic).
211            // Actually, if we just subtract the value we *saw* before starting the save, new items added during save will just remain in the counter.
212            // But we are saving the *entire* vector list, so any items added during save (if that were possible with the lock held) would be included?
213            // `embeddings` is protected by RwLock.
214            // We take a read lock on `embeddings` to clone/serialize.
215            // While we hold the read lock, NO new items can be added (add_batch needs write lock).
216            // So the snapshot is consistent.
217            // The race happens *after* we drop the read lock and *before* we reset dirty_count.
218            // During that window, a writer could add items and increment dirty_count.
219            // If we then reset to 0, we lose those counts.
220            //
221            // Fix: Read dirty_count *while holding the read lock*.
222            // Since writers are blocked, dirty_count cannot change while we hold the lock.
223            // So `current_dirty` will be exactly the number of unsaved items included in `entries`.
224            // Then we perform the save (IO).
225            // Finally we subtract `current_dirty`.
226            // If new items come in during IO (after read lock drop), they increment counter.
227            // Subtracting `current_dirty` leaves those new increments intact. Correct.
228
229            let (entries, current_dirty) = {
230                let guard = self.embeddings.read().unwrap();
231                (guard.clone(), self.dirty_count.load(Ordering::Relaxed))
232            };
233
234            let data = VectorData { entries };
235            save_bincode(&path.join("vectors.bin"), &data)?;
236
237            if current_dirty > 0 {
238                let _ = self.dirty_count.fetch_sub(current_dirty, Ordering::Relaxed);
239            }
240        }
241        Ok(())
242    }
243
244    /// Force save to disk
245    pub fn flush(&self) -> Result<()> {
246        self.save_vectors()
247    }
248
249    /// Generate embedding for a text using HuggingFace Inference API
250    /// (Mocked if MOCK_EMBEDDINGS is set)
251    pub async fn embed(&self, text: &str) -> Result<Vec<f32>> {
252        if std::env::var("MOCK_EMBEDDINGS").is_ok() {
253            // Return random embedding for testing
254            use rand::Rng;
255            let mut rng = rand::rng();
256            let vec: Vec<f32> = (0..self.dimensions).map(|_| rng.random()).collect();
257            return Ok(vec);
258        }
259
260        let embeddings = self.embed_batch(vec![text.to_string()]).await?;
261        Ok(embeddings[0].clone())
262    }
263
264    /// Generate embeddings for multiple texts using HuggingFace Inference API
265    pub async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
266        if texts.is_empty() {
267            return Ok(Vec::new());
268        }
269
270        if std::env::var("MOCK_EMBEDDINGS").is_ok() {
271            use rand::Rng;
272            let mut rng = rand::rng();
273            let mut results = Vec::new();
274            for _ in 0..texts.len() {
275                let vec: Vec<f32> = (0..self.dimensions).map(|_| rng.random()).collect();
276                results.push(vec);
277            }
278            return Ok(results);
279        }
280
281        let url = format!(
282            "{}/{}/pipeline/feature-extraction",
283            HUGGINGFACE_API_URL, self.model
284        );
285
286        // HuggingFace accepts array of strings for inputs
287        let mut request = self.client.post(&url).json(&serde_json::json!({
288            "inputs": texts,
289        }));
290
291        // Add auth token if available
292        if let Some(ref token) = self.api_token {
293            request = request.header("Authorization", format!("Bearer {}", token));
294        }
295
296        let response = request.send().await?;
297
298        if !response.status().is_success() {
299            let error_text = response.text().await?;
300            anyhow::bail!("HuggingFace API error: {}", error_text);
301        }
302
303        let response_json: serde_json::Value = response.json().await?;
304        let mut results = Vec::new();
305
306        if let Some(arr) = response_json.as_array() {
307            for item in arr {
308                let vec: Vec<f32> = serde_json::from_value(item.clone())
309                    .map_err(|e| anyhow::anyhow!("Failed to parse embedding: {}", e))?;
310
311                if vec.len() != self.dimensions {
312                    anyhow::bail!("Expected {} dimensions, got {}", self.dimensions, vec.len());
313                }
314                results.push(vec);
315            }
316        } else {
317            // Handle case where we sent 1 text and got single flat array [0.1, ...]
318            if texts.len() == 1 {
319                if let Ok(vec) = serde_json::from_value::<Vec<f32>>(response_json) {
320                    if vec.len() == self.dimensions {
321                        results.push(vec);
322                    }
323                }
324            }
325        }
326
327        if results.len() != texts.len() {
328            anyhow::bail!("Expected {} embeddings, got {}", texts.len(), results.len());
329        }
330
331        Ok(results)
332    }
333
334    /// Add a key with its text content to the index
335    pub async fn add(
336        &self,
337        key: &str,
338        content: &str,
339        metadata: serde_json::Value,
340    ) -> Result<usize> {
341        let results = self
342            .add_batch(vec![(key.to_string(), content.to_string(), metadata)])
343            .await?;
344        Ok(results[0])
345    }
346
347    /// Add multiple keys with their text content to the index
348    pub async fn add_batch(
349        &self,
350        items: Vec<(String, String, serde_json::Value)>,
351    ) -> Result<Vec<usize>> {
352        // Filter out existing keys
353        let mut new_items = Vec::new();
354        let mut result_ids = vec![0; items.len()];
355        let mut new_indices = Vec::new(); // Map index in new_items to index in items
356
357        {
358            let key_map = self.key_to_id.read().unwrap();
359            for (i, (key, content, _)) in items.iter().enumerate() {
360                if let Some(&id) = key_map.get(key) {
361                    result_ids[i] = id;
362                } else {
363                    new_items.push(content.clone());
364                    new_indices.push(i);
365                }
366            }
367        }
368
369        if new_items.is_empty() {
370            return Ok(result_ids);
371        }
372
373        // Generate embeddings via HuggingFace API
374        let embeddings = self.embed_batch(new_items).await?;
375
376        let mut ids_to_add = Vec::new();
377        let mut searcher = hnsw::Searcher::default();
378
379        // Add to HNSW index and maps
380        {
381            let mut index = self.index.write().unwrap();
382            let mut key_map = self.key_to_id.write().unwrap();
383            let mut id_map = self.id_to_key.write().unwrap();
384            let mut metadata_map = self.key_to_metadata.write().unwrap();
385            let mut embs = self.embeddings.write().unwrap();
386
387            for (i, embedding) in embeddings.into_iter().enumerate() {
388                let original_idx = new_indices[i];
389                let (key, _, metadata) = &items[original_idx];
390
391                // Double check if inserted in race condition
392                if let Some(&id) = key_map.get(key) {
393                    result_ids[original_idx] = id;
394                    continue;
395                }
396
397                let id = index.insert(embedding.clone(), &mut searcher);
398                key_map.insert(key.clone(), id);
399                id_map.insert(id, key.clone());
400                metadata_map.insert(key.clone(), metadata.clone());
401
402                embs.push(VectorEntry {
403                    key: key.clone(),
404                    embedding,
405                    metadata: metadata.clone(),
406                });
407
408                result_ids[original_idx] = id;
409                ids_to_add.push(id);
410            }
411        }
412
413        if !ids_to_add.is_empty() {
414            let count = self
415                .dirty_count
416                .fetch_add(ids_to_add.len(), Ordering::Relaxed);
417            if count + ids_to_add.len() >= self.auto_save_threshold {
418                let _ = self.save_vectors();
419            }
420        }
421
422        Ok(result_ids)
423    }
424
425    /// Search for similar vectors
426    pub async fn search(&self, query: &str, k: usize) -> Result<Vec<SearchResult>> {
427        // Generate query embedding via HuggingFace API
428        let query_embedding = self.embed(query).await?;
429
430        // Search HNSW index
431        let mut searcher = hnsw::Searcher::default();
432
433        let index = self.index.read().unwrap();
434        let len = index.len();
435        if len == 0 {
436            return Ok(Vec::new());
437        }
438
439        let k = k.min(len);
440        let ef = k.max(50); // Ensure ef is at least k, but usually 50+
441
442        let mut neighbors = vec![
443            space::Neighbor {
444                index: 0,
445                distance: u32::MAX
446            };
447            k
448        ];
449
450        let found_neighbors = index.nearest(&query_embedding, ef, &mut searcher, &mut neighbors);
451
452        // Convert to results
453        let id_map = self.id_to_key.read().unwrap();
454        let metadata_map = self.key_to_metadata.read().unwrap();
455
456        let results: Vec<SearchResult> = found_neighbors
457            .iter()
458            .filter_map(|neighbor| {
459                id_map.get(&neighbor.index).map(|key| {
460                    let score_f32 = (neighbor.distance as f32) / 1_000_000.0;
461
462                    let metadata = metadata_map
463                        .get(key)
464                        .cloned()
465                        .unwrap_or(serde_json::Value::Null);
466
467                    let uri = metadata
468                        .get("uri")
469                        .and_then(|v| v.as_str())
470                        .unwrap_or(key)
471                        .to_string();
472
473                    SearchResult {
474                        key: key.clone(),
475                        score: 1.0 / (1.0 + score_f32),
476                        metadata,
477                        uri,
478                    }
479                })
480            })
481            .collect();
482
483        Ok(results)
484    }
485
486    pub fn get_key(&self, id: usize) -> Option<String> {
487        self.id_to_key.read().unwrap().get(&id).cloned()
488    }
489
490    pub fn get_id(&self, key: &str) -> Option<usize> {
491        self.key_to_id.read().unwrap().get(key).copied()
492    }
493
494    pub fn len(&self) -> usize {
495        self.key_to_id.read().unwrap().len()
496    }
497
498    pub fn is_empty(&self) -> bool {
499        self.len() == 0
500    }
501
502    /// Compaction: rebuild index from stored embeddings, removing stale entries
503    pub fn compact(&self) -> Result<usize> {
504        let embeddings = self.embeddings.read().unwrap();
505        let current_keys: std::collections::HashSet<_> =
506            self.key_to_id.read().unwrap().keys().cloned().collect();
507
508        if current_keys.is_empty() && !embeddings.is_empty() {
509            return Ok(0);
510        }
511
512        // Filter to only current Keys
513        let active_entries: Vec<_> = embeddings
514            .iter()
515            .filter(|e| current_keys.contains(&e.key))
516            .cloned()
517            .collect();
518
519        let removed = embeddings.len() - active_entries.len();
520
521        if removed == 0 {
522            return Ok(0);
523        }
524
525        // Rebuild index
526        let mut new_index = hnsw::Hnsw::new(Euclidian);
527        let mut new_id_to_key = std::collections::HashMap::new();
528        let mut new_key_to_id = std::collections::HashMap::new();
529        let mut new_key_to_metadata = std::collections::HashMap::new();
530        let mut searcher = hnsw::Searcher::default();
531
532        for entry in &active_entries {
533            if entry.embedding.len() == self.dimensions {
534                let id = new_index.insert(entry.embedding.clone(), &mut searcher);
535                new_id_to_key.insert(id, entry.key.clone());
536                new_key_to_id.insert(entry.key.clone(), id);
537                new_key_to_metadata.insert(entry.key.clone(), entry.metadata.clone());
538            }
539        }
540
541        // Swap in new index
542        *self.index.write().unwrap() = new_index;
543        *self.id_to_key.write().unwrap() = new_id_to_key;
544        *self.key_to_id.write().unwrap() = new_key_to_id;
545        *self.key_to_metadata.write().unwrap() = new_key_to_metadata;
546
547        // Update embeddings (drop takes write lock)
548        drop(embeddings);
549        *self.embeddings.write().unwrap() = active_entries;
550
551        let _ = self.save_vectors();
552
553        Ok(removed)
554    }
555
556    /// Remove a Key from the vector store
557    pub fn remove(&self, key: &str) -> bool {
558        let mut key_map = self.key_to_id.write().unwrap();
559        let mut id_map = self.id_to_key.write().unwrap();
560        let mut metadata_map = self.key_to_metadata.write().unwrap();
561
562        if let Some(id) = key_map.remove(key) {
563            id_map.remove(&id);
564            metadata_map.remove(key);
565            // Note: actual index entry remains until compaction
566            true
567        } else {
568            false
569        }
570    }
571
572    /// Get storage stats
573    pub fn stats(&self) -> (usize, usize, usize) {
574        let embeddings_count = self.embeddings.read().unwrap().len();
575        let active_count = self.key_to_id.read().unwrap().len();
576        let stale_count = embeddings_count.saturating_sub(active_count);
577        (active_count, stale_count, embeddings_count)
578    }
579}