//! HNSW-backed vector store for `synapse_core` (vector_store.rs).
//! Embeddings come from the HuggingFace Inference API, a local
//! fastembed model, or mock vectors (tests).
use crate::persistence::{load_bincode, save_bincode};
use anyhow::Result;
use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
use hnsw::Hnsw;
use rand_pcg::Pcg64;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex, RwLock};

13const HUGGINGFACE_API_URL: &str = "https://router.huggingface.co/hf-inference/models";
14const DEFAULT_MODEL: &str = "sentence-transformers/all-MiniLM-L6-v2"; // 384 dims, fast
15const DEFAULT_DIMENSIONS: usize = 384;
16const DEFAULT_AUTO_SAVE_THRESHOLD: usize = 100;
17
18/// Euclidean distance metric for HNSW
19#[derive(Default, Clone)]
20pub struct Euclidian;
21
22impl space::Metric<Vec<f32>> for Euclidian {
23    type Unit = u32;
24    fn distance(&self, a: &Vec<f32>, b: &Vec<f32>) -> u32 {
25        let len = a.len().min(b.len());
26        let mut dist_sq = 0.0;
27        for i in 0..len {
28            let diff = a[i] - b[i];
29            dist_sq += diff * diff;
30        }
31        // Use fixed-point arithmetic to avoid bitwise comparison issues
32        (dist_sq.sqrt() * 1_000_000.0) as u32
33    }
34}
35
/// Persisted vector data: the full entry set written to / read from disk
/// (bincode `vectors.bin`, with a JSON migration path in `VectorStore::new`).
#[derive(Serialize, Deserialize, Default)]
struct VectorData {
    // All stored entries; the HNSW index is rebuilt from these on load.
    entries: Vec<VectorEntry>,
}

/// A single persisted embedding record.
#[derive(Serialize, Deserialize, Clone)]
struct VectorEntry {
    /// Unique identifier for this vector (could be URI or Hash)
    key: String,
    // Embedding vector; entries whose length differs from the store's
    // configured dimensions are skipped on load.
    embedding: Vec<f32>,
    /// Optional metadata associated with the vector
    /// (defaults to JSON `null` when absent in older files).
    #[serde(default)]
    metadata: serde_json::Value,
}

/// Vector store using HuggingFace Inference API for embeddings
/// (with a local fastembed fallback and a mock mode for tests).
///
/// Note: `remove` only unregisters a key; its vector stays in the HNSW
/// index until `compact` rebuilds it.
pub struct VectorStore {
    /// HNSW index for fast approximate nearest neighbor search
    index: Arc<RwLock<Hnsw<Euclidian, Vec<f32>, Pcg64, 16, 32>>>,
    /// Mapping from node ID (internal) to Key
    id_to_key: Arc<RwLock<HashMap<usize, String>>>,
    /// Mapping from Key to node ID (internal)
    key_to_id: Arc<RwLock<HashMap<String, usize>>>,
    /// Mapping from Key to Metadata (for fast retrieval)
    key_to_metadata: Arc<RwLock<HashMap<String, serde_json::Value>>>,
    /// Storage path for persistence (None disables persistence entirely)
    storage_path: Option<PathBuf>,
    /// HTTP client for HuggingFace API
    client: Client,
    /// HuggingFace API URL (configurable)
    api_url: String,
    /// HuggingFace API token (optional, for rate limits)
    api_token: Option<String>,
    /// Model name
    model: String,
    /// Vector dimensions
    dimensions: usize,
    /// Stored embeddings for persistence
    embeddings: Arc<RwLock<Vec<VectorEntry>>>,
    /// Number of unsaved changes
    dirty_count: Arc<AtomicUsize>,
    /// Threshold for auto-save
    auto_save_threshold: usize,
    /// Local embedding model (optional; used when no API token is set)
    local_model: Option<Arc<Mutex<TextEmbedding>>>,
}

/// A single hit returned by `VectorStore::search`.
#[derive(Debug, Serialize, Deserialize)]
pub struct SearchResult {
    /// The unique key
    pub key: String,
    /// Similarity score computed as `1 / (1 + euclidean_distance)`,
    /// so smaller distances map to scores closer to 1.
    pub score: f32,
    /// Metadata (including original URI if applicable)
    pub metadata: serde_json::Value,
    // Helper to access URI from metadata if it exists, for backward compatibility
    // (falls back to `key` when metadata has no "uri" field).
    pub uri: String,
}

impl VectorStore {
    /// Create a new vector store for a namespace.
    ///
    /// Environment variables consulted:
    /// - `GRAPH_STORAGE_PATH`: base directory for persistence (namespace appended);
    ///   unset disables persistence.
    /// - `VECTOR_DIMENSIONS`: embedding dimensionality (default 384).
    /// - `HUGGINGFACE_API_TOKEN` / `HUGGINGFACE_API_URL`: remote embedding config.
    /// - `MOCK_EMBEDDINGS`: when set, skips initializing the local model.
    ///
    /// Previously persisted vectors are loaded (bincode first, then a JSON
    /// fallback covering both current and legacy formats); entries with a
    /// mismatched dimensionality are silently skipped.
    ///
    /// # Errors
    /// Fails if the local fastembed model cannot be initialized.
    pub fn new(namespace: &str) -> Result<Self> {
        // Try to get storage path from environment
        let storage_path = std::env::var("GRAPH_STORAGE_PATH")
            .ok()
            .map(|p| PathBuf::from(p).join(namespace));

        // Get dimensions from env or default
        let dimensions = std::env::var("VECTOR_DIMENSIONS")
            .ok()
            .and_then(|s| s.parse().ok())
            .unwrap_or(DEFAULT_DIMENSIONS);

        // Create HNSW index with Euclidian metric
        let mut index = Hnsw::new(Euclidian);
        let mut id_to_key = HashMap::new();
        let mut key_to_id = HashMap::new();
        let mut key_to_metadata = HashMap::new();
        let mut embeddings = Vec::new();

        // Try to load persisted vectors
        if let Some(ref path) = storage_path {
            let vectors_bin = path.join("vectors.bin");
            let vectors_json = path.join("vectors.json");

            let loaded_data = if vectors_bin.exists() {
                load_bincode::<VectorData>(&vectors_bin).ok()
            } else if vectors_json.exists() {
                // Fallback / Migration from JSON
                let content = std::fs::read_to_string(&vectors_json).ok();
                if let Some(content) = content {
                    // Try new format first
                    if let Ok(data) = serde_json::from_str::<VectorData>(&content) {
                        Some(data)
                    } else {
                        // Fallback: Try loading old format (pre-metadata: `uri`
                        // field instead of `key`, no metadata)
                        #[derive(Serialize, Deserialize)]
                        struct OldVectorData {
                            entries: Vec<OldVectorEntry>,
                        }
                        #[derive(Serialize, Deserialize)]
                        struct OldVectorEntry {
                            uri: String,
                            embedding: Vec<f32>,
                        }

                        if let Ok(old_data) = serde_json::from_str::<OldVectorData>(&content) {
                            // Migrate: key = uri, and preserve the uri in metadata.
                            let entries = old_data
                                .entries
                                .into_iter()
                                .map(|old| VectorEntry {
                                    key: old.uri.clone(),
                                    embedding: old.embedding,
                                    metadata: serde_json::json!({ "uri": old.uri }),
                                })
                                .collect();
                            Some(VectorData { entries })
                        } else {
                            None
                        }
                    }
                } else {
                    None
                }
            } else {
                None
            };

            if let Some(data) = loaded_data {
                let mut searcher = hnsw::Searcher::default();
                for entry in data.entries {
                    // Skip entries whose dimensionality doesn't match the
                    // configured size (e.g. after a model change).
                    if entry.embedding.len() == dimensions {
                        let id = index.insert(entry.embedding.clone(), &mut searcher);
                        id_to_key.insert(id, entry.key.clone());
                        key_to_id.insert(entry.key.clone(), id);
                        key_to_metadata.insert(entry.key.clone(), entry.metadata.clone());
                        embeddings.push(entry);
                    }
                }
                eprintln!(
                    "Loaded {} vectors from disk (dim={})",
                    embeddings.len(),
                    dimensions
                );
            }
        }

        // Get API token from environment (optional)
        let api_token = std::env::var("HUGGINGFACE_API_TOKEN").ok();

        // Get API URL from environment (or default)
        let api_url = std::env::var("HUGGINGFACE_API_URL")
            .unwrap_or_else(|_| HUGGINGFACE_API_URL.to_string());

        // Configured client with timeout
        let client = Client::builder()
            .timeout(std::time::Duration::from_secs(30))
            .build()
            .unwrap_or_else(|_| Client::new());

        // Without an API token (and outside mock mode), embed locally via
        // fastembed; model weights are downloaded on first use.
        let local_model = if api_token.is_none() && std::env::var("MOCK_EMBEDDINGS").is_err() {
            eprintln!("Initializing local embedding model (this may take a moment to download)...");
            let model = TextEmbedding::try_new(
                InitOptions::new(EmbeddingModel::AllMiniLML6V2).with_show_download_progress(true),
            )?;
            Some(Arc::new(Mutex::new(model)))
        } else {
            None
        };

        Ok(Self {
            index: Arc::new(RwLock::new(index)),
            id_to_key: Arc::new(RwLock::new(id_to_key)),
            key_to_id: Arc::new(RwLock::new(key_to_id)),
            key_to_metadata: Arc::new(RwLock::new(key_to_metadata)),
            storage_path,
            client,
            api_url,
            api_token,
            model: DEFAULT_MODEL.to_string(),
            dimensions,
            embeddings: Arc::new(RwLock::new(embeddings)),
            dirty_count: Arc::new(AtomicUsize::new(0)),
            auto_save_threshold: DEFAULT_AUTO_SAVE_THRESHOLD,
            local_model,
        })
    }

224    /// Save vectors to disk
225    fn save_vectors(&self) -> Result<()> {
226        if let Some(ref path) = self.storage_path {
227            std::fs::create_dir_all(path)?;
228
229            // Hold read lock during serialization AND dirty count reset to avoid race condition.
230            // Wait, we need to read the dirty count at the time of snapshot.
231            // But we can't atomically read-and-subtract without a loop if we don't hold a write lock on dirty_count (which is Atomic).
232            // Actually, if we just subtract the value we *saw* before starting the save, new items added during save will just remain in the counter.
233            // But we are saving the *entire* vector list, so any items added during save (if that were possible with the lock held) would be included?
234            // `embeddings` is protected by RwLock.
235            // We take a read lock on `embeddings` to clone/serialize.
236            // While we hold the read lock, NO new items can be added (add_batch needs write lock).
237            // So the snapshot is consistent.
238            // The race happens *after* we drop the read lock and *before* we reset dirty_count.
239            // During that window, a writer could add items and increment dirty_count.
240            // If we then reset to 0, we lose those counts.
241            //
242            // Fix: Read dirty_count *while holding the read lock*.
243            // Since writers are blocked, dirty_count cannot change while we hold the lock.
244            // So `current_dirty` will be exactly the number of unsaved items included in `entries`.
245            // Then we perform the save (IO).
246            // Finally we subtract `current_dirty`.
247            // If new items come in during IO (after read lock drop), they increment counter.
248            // Subtracting `current_dirty` leaves those new increments intact. Correct.
249
250            let (entries, current_dirty) = {
251                let guard = self.embeddings.read().unwrap();
252                (guard.clone(), self.dirty_count.load(Ordering::Relaxed))
253            };
254
255            let data = VectorData { entries };
256            save_bincode(&path.join("vectors.bin"), &data)?;
257
258            if current_dirty > 0 {
259                let _ = self.dirty_count.fetch_sub(current_dirty, Ordering::Relaxed);
260            }
261        }
262        Ok(())
263    }
264
    /// Force save to disk, regardless of the auto-save threshold.
    ///
    /// # Errors
    /// Propagates filesystem/serialization errors from `save_vectors`.
    pub fn flush(&self) -> Result<()> {
        self.save_vectors()
    }

270    /// Generate embedding for a text using HuggingFace Inference API
271    /// (Mocked if MOCK_EMBEDDINGS is set)
272    pub async fn embed(&self, text: &str) -> Result<Vec<f32>> {
273        if std::env::var("MOCK_EMBEDDINGS").is_ok() {
274            // Return random embedding for testing
275            use rand::Rng;
276            let mut rng = rand::rng();
277            let vec: Vec<f32> = (0..self.dimensions).map(|_| rng.random()).collect();
278            return Ok(vec);
279        }
280
281        let embeddings = self.embed_batch(vec![text.to_string()]).await?;
282        Ok(embeddings[0].clone())
283    }
284
    /// Generate embeddings for multiple texts using HuggingFace Inference API.
    ///
    /// Backend resolution order:
    /// 1. `MOCK_EMBEDDINGS` set → random vectors (testing only).
    /// 2. A local fastembed model was initialized → embed on a blocking thread.
    /// 3. Otherwise → remote HuggingFace feature-extraction endpoint.
    ///
    /// # Errors
    /// Fails on HTTP/API errors, unparsable responses, wrong dimensionality,
    /// or a result count that doesn't match `texts.len()`.
    pub async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        if std::env::var("MOCK_EMBEDDINGS").is_ok() {
            use rand::Rng;
            let mut rng = rand::rng();
            let mut results = Vec::new();
            for _ in 0..texts.len() {
                let vec: Vec<f32> = (0..self.dimensions).map(|_| rng.random()).collect();
                results.push(vec);
            }
            return Ok(results);
        }

        if let Some(ref model) = self.local_model {
            // fastembed is synchronous/CPU-bound: run it on the blocking pool
            // so the async runtime isn't stalled; the Mutex serializes access.
            let model = model.clone();
            let texts = texts.clone();
            let embeddings = tokio::task::spawn_blocking(move || {
                let mut model = model.lock().unwrap();
                model.embed(texts, None)
            })
            .await??;
            // NOTE(review): unlike the API path below, neither the count nor
            // the dimensionality of these vectors is validated — assumed to
            // match `self.dimensions`; verify model/config agreement.
            return Ok(embeddings);
        }

        let url = format!(
            "{}/{}/pipeline/feature-extraction",
            self.api_url, self.model
        );

        // HuggingFace accepts array of strings for inputs
        let mut request = self.client.post(&url).json(&serde_json::json!({
            "inputs": texts,
        }));

        // Add auth token if available
        if let Some(ref token) = self.api_token {
            request = request.header("Authorization", format!("Bearer {}", token));
        }

        let response = request.send().await?;

        if !response.status().is_success() {
            let error_text = response.text().await?;
            anyhow::bail!("HuggingFace API error: {}", error_text);
        }

        let response_json: serde_json::Value = response.json().await?;
        let mut results = Vec::new();

        if let Some(arr) = response_json.as_array() {
            // Expected shape: one array of f32 per input text.
            for item in arr {
                let vec: Vec<f32> = serde_json::from_value(item.clone())
                    .map_err(|e| anyhow::anyhow!("Failed to parse embedding: {}", e))?;

                if vec.len() != self.dimensions {
                    anyhow::bail!("Expected {} dimensions, got {}", self.dimensions, vec.len());
                }
                results.push(vec);
            }
        } else {
            // Handle case where we sent 1 text and got single flat array [0.1, ...]
            if texts.len() == 1 {
                if let Ok(vec) = serde_json::from_value::<Vec<f32>>(response_json) {
                    if vec.len() == self.dimensions {
                        results.push(vec);
                    }
                }
            }
        }

        // Any shortfall (including a flat single result silently dropped
        // above for wrong dimensionality) is reported here.
        if results.len() != texts.len() {
            anyhow::bail!("Expected {} embeddings, got {}", texts.len(), results.len());
        }

        Ok(results)
    }

366    /// Add a key with its text content to the index
367    pub async fn add(
368        &self,
369        key: &str,
370        content: &str,
371        metadata: serde_json::Value,
372    ) -> Result<usize> {
373        let results = self
374            .add_batch(vec![(key.to_string(), content.to_string(), metadata)])
375            .await?;
376        Ok(results[0])
377    }
378
    /// Add multiple keys with their text content to the index.
    ///
    /// Returns one internal node id per input item, in input order. Items
    /// whose key is already indexed are not re-embedded; their existing id
    /// is returned. May trigger a best-effort auto-save once the number of
    /// unsaved entries reaches `auto_save_threshold`.
    pub async fn add_batch(
        &self,
        items: Vec<(String, String, serde_json::Value)>,
    ) -> Result<Vec<usize>> {
        // Filter out existing keys
        let mut new_items = Vec::new();
        let mut result_ids = vec![0; items.len()];
        let mut new_indices = Vec::new(); // Map index in new_items to index in items

        {
            let key_map = self.key_to_id.read().unwrap();
            for (i, (key, content, _)) in items.iter().enumerate() {
                if let Some(&id) = key_map.get(key) {
                    result_ids[i] = id;
                } else {
                    new_items.push(content.clone());
                    new_indices.push(i);
                }
            }
        } // read lock dropped before the (potentially slow) embedding call

        if new_items.is_empty() {
            return Ok(result_ids);
        }

        // Generate embeddings via HuggingFace API
        let embeddings = self.embed_batch(new_items).await?;

        let mut ids_to_add = Vec::new();
        let mut searcher = hnsw::Searcher::default();

        // Add to HNSW index and maps.
        // Lock order: index, key_to_id, id_to_key, key_to_metadata,
        // embeddings — other writers must use the same order to avoid deadlock.
        {
            let mut index = self.index.write().unwrap();
            let mut key_map = self.key_to_id.write().unwrap();
            let mut id_map = self.id_to_key.write().unwrap();
            let mut metadata_map = self.key_to_metadata.write().unwrap();
            let mut embs = self.embeddings.write().unwrap();

            for (i, embedding) in embeddings.into_iter().enumerate() {
                let original_idx = new_indices[i];
                let (key, _, metadata) = &items[original_idx];

                // Double check if inserted in race condition
                // (another task may have added this key while we were embedding)
                if let Some(&id) = key_map.get(key) {
                    result_ids[original_idx] = id;
                    continue;
                }

                let id = index.insert(embedding.clone(), &mut searcher);
                key_map.insert(key.clone(), id);
                id_map.insert(id, key.clone());
                metadata_map.insert(key.clone(), metadata.clone());

                embs.push(VectorEntry {
                    key: key.clone(),
                    embedding,
                    metadata: metadata.clone(),
                });

                result_ids[original_idx] = id;
                ids_to_add.push(id);
            }
        }

        if !ids_to_add.is_empty() {
            // fetch_add returns the previous value; save when the new total
            // reaches the threshold. Save errors are intentionally ignored
            // (best-effort auto-save); use `flush` for a checked save.
            let count = self
                .dirty_count
                .fetch_add(ids_to_add.len(), Ordering::Relaxed);
            if count + ids_to_add.len() >= self.auto_save_threshold {
                let _ = self.save_vectors();
            }
        }

        Ok(result_ids)
    }

    /// Search for similar vectors.
    ///
    /// Returns up to `k` results (capped at the index size), scored as
    /// `1 / (1 + euclidean_distance)` — higher is closer. Keys that were
    /// `remove`d are filtered out even though their vectors remain in the
    /// index until `compact`.
    pub async fn search(&self, query: &str, k: usize) -> Result<Vec<SearchResult>> {
        // Generate query embedding via HuggingFace API
        let query_embedding = self.embed(query).await?;

        // Search HNSW index
        let mut searcher = hnsw::Searcher::default();

        let index = self.index.read().unwrap();
        let len = index.len();
        if len == 0 {
            return Ok(Vec::new());
        }

        let k = k.min(len);
        let ef = k.max(50); // Ensure ef is at least k, but usually 50+

        // Pre-sized output buffer filled by `Hnsw::nearest`; its length
        // bounds how many neighbors are returned.
        let mut neighbors = vec![
            space::Neighbor {
                index: 0,
                distance: u32::MAX
            };
            k
        ];

        let found_neighbors = index.nearest(&query_embedding, ef, &mut searcher, &mut neighbors);

        // Convert to results
        let id_map = self.id_to_key.read().unwrap();
        let metadata_map = self.key_to_metadata.read().unwrap();

        let results: Vec<SearchResult> = found_neighbors
            .iter()
            .filter_map(|neighbor| {
                // Removed keys are absent from id_map, so stale index
                // entries drop out of the results here.
                id_map.get(&neighbor.index).map(|key| {
                    // Undo the metric's fixed-point scaling (see `Euclidian`).
                    let score_f32 = (neighbor.distance as f32) / 1_000_000.0;

                    let metadata = metadata_map
                        .get(key)
                        .cloned()
                        .unwrap_or(serde_json::Value::Null);

                    // Backward compatibility: surface metadata["uri"], else the key.
                    let uri = metadata
                        .get("uri")
                        .and_then(|v| v.as_str())
                        .unwrap_or(key)
                        .to_string();

                    SearchResult {
                        key: key.clone(),
                        score: 1.0 / (1.0 + score_f32),
                        metadata,
                        uri,
                    }
                })
            })
            .collect();

        Ok(results)
    }

518    pub fn get_key(&self, id: usize) -> Option<String> {
519        self.id_to_key.read().unwrap().get(&id).cloned()
520    }
521
522    pub fn get_id(&self, key: &str) -> Option<usize> {
523        self.key_to_id.read().unwrap().get(key).copied()
524    }
525
526    pub fn len(&self) -> usize {
527        self.key_to_id.read().unwrap().len()
528    }
529
530    pub fn is_empty(&self) -> bool {
531        self.len() == 0
532    }
533
534    /// Compaction: rebuild index from stored embeddings, removing stale entries
535    pub fn compact(&self) -> Result<usize> {
536        let embeddings = self.embeddings.read().unwrap();
537        let current_keys: std::collections::HashSet<_> =
538            self.key_to_id.read().unwrap().keys().cloned().collect();
539
540        if current_keys.is_empty() && !embeddings.is_empty() {
541            return Ok(0);
542        }
543
544        // Filter to only current Keys
545        let active_entries: Vec<_> = embeddings
546            .iter()
547            .filter(|e| current_keys.contains(&e.key))
548            .cloned()
549            .collect();
550
551        let removed = embeddings.len() - active_entries.len();
552
553        if removed == 0 {
554            return Ok(0);
555        }
556
557        // Rebuild index
558        let mut new_index = hnsw::Hnsw::new(Euclidian);
559        let mut new_id_to_key = std::collections::HashMap::new();
560        let mut new_key_to_id = std::collections::HashMap::new();
561        let mut new_key_to_metadata = std::collections::HashMap::new();
562        let mut searcher = hnsw::Searcher::default();
563
564        for entry in &active_entries {
565            if entry.embedding.len() == self.dimensions {
566                let id = new_index.insert(entry.embedding.clone(), &mut searcher);
567                new_id_to_key.insert(id, entry.key.clone());
568                new_key_to_id.insert(entry.key.clone(), id);
569                new_key_to_metadata.insert(entry.key.clone(), entry.metadata.clone());
570            }
571        }
572
573        // Swap in new index
574        *self.index.write().unwrap() = new_index;
575        *self.id_to_key.write().unwrap() = new_id_to_key;
576        *self.key_to_id.write().unwrap() = new_key_to_id;
577        *self.key_to_metadata.write().unwrap() = new_key_to_metadata;
578
579        // Update embeddings (drop takes write lock)
580        drop(embeddings);
581        *self.embeddings.write().unwrap() = active_entries;
582
583        let _ = self.save_vectors();
584
585        Ok(removed)
586    }
587
588    /// Remove a Key from the vector store
589    pub fn remove(&self, key: &str) -> bool {
590        let mut key_map = self.key_to_id.write().unwrap();
591        let mut id_map = self.id_to_key.write().unwrap();
592        let mut metadata_map = self.key_to_metadata.write().unwrap();
593
594        if let Some(id) = key_map.remove(key) {
595            id_map.remove(&id);
596            metadata_map.remove(key);
597            // Note: actual index entry remains until compaction
598            true
599        } else {
600            false
601        }
602    }
603
604    /// Get storage stats
605    pub fn stats(&self) -> (usize, usize, usize) {
606        let embeddings_count = self.embeddings.read().unwrap().len();
607        let active_count = self.key_to_id.read().unwrap().len();
608        let stale_count = embeddings_count.saturating_sub(active_count);
609        (active_count, stale_count, embeddings_count)
610    }
}