// synapse_core/vector_store.rs

1use anyhow::Result;
2use hnsw::Hnsw;
3use rand_pcg::Pcg64;
4use reqwest::Client;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::path::PathBuf;
8use std::sync::{Arc, RwLock};
9
/// Base URL of the HuggingFace Inference router. Each request appends
/// `/<model>/pipeline/feature-extraction` to it.
const HUGGINGFACE_API_URL: &str = "https://router.huggingface.co/hf-inference/models";

/// Embedding model used by default (384-dimensional output, fast).
const DEFAULT_MODEL: &str = "sentence-transformers/all-MiniLM-L6-v2";

/// Embedding length assumed when `VECTOR_DIMENSIONS` is not set.
/// It matches the output size of [`DEFAULT_MODEL`].
const DEFAULT_DIMENSIONS: usize = 384;
13
14/// Euclidean distance metric for HNSW
15#[derive(Default, Clone)]
16pub struct Euclidian;
17
18impl space::Metric<Vec<f32>> for Euclidian {
19    type Unit = u32;
20    fn distance(&self, a: &Vec<f32>, b: &Vec<f32>) -> u32 {
21        let len = a.len().min(b.len());
22        let mut dist_sq = 0.0;
23        for i in 0..len {
24            let diff = a[i] - b[i];
25            dist_sq += diff * diff;
26        }
27        // Use fixed-point arithmetic to avoid bitwise comparison issues
28        (dist_sq.sqrt() * 1_000_000.0) as u32
29    }
30}
31
32/// Persisted vector data
33#[derive(Serialize, Deserialize, Default)]
34struct VectorData {
35    entries: Vec<VectorEntry>,
36}
37
38#[derive(Serialize, Deserialize, Clone)]
39struct VectorEntry {
40    /// Unique identifier for this vector (could be URI or Hash)
41    key: String,
42    embedding: Vec<f32>,
43    /// Optional metadata associated with the vector
44    #[serde(default)]
45    metadata: serde_json::Value,
46}
47
48/// Vector store using HuggingFace Inference API for embeddings
49pub struct VectorStore {
50    /// HNSW index for fast approximate nearest neighbor search
51    index: Arc<RwLock<Hnsw<Euclidian, Vec<f32>, Pcg64, 16, 32>>>,
52    /// Mapping from node ID (internal) to Key
53    id_to_key: Arc<RwLock<HashMap<usize, String>>>,
54    /// Mapping from Key to node ID (internal)
55    key_to_id: Arc<RwLock<HashMap<String, usize>>>,
56    /// Mapping from Key to Metadata (for fast retrieval)
57    key_to_metadata: Arc<RwLock<HashMap<String, serde_json::Value>>>,
58    /// Storage path for persistence
59    storage_path: Option<PathBuf>,
60    /// HTTP client for HuggingFace API
61    client: Client,
62    /// HuggingFace API token (optional, for rate limits)
63    api_token: Option<String>,
64    /// Model name
65    model: String,
66    /// Vector dimensions
67    dimensions: usize,
68    /// Stored embeddings for persistence
69    embeddings: Arc<RwLock<Vec<VectorEntry>>>,
70}
71
72#[derive(Debug, Serialize, Deserialize)]
73pub struct SearchResult {
74    /// The unique key
75    pub key: String,
76    pub score: f32,
77    /// Metadata (including original URI if applicable)
78    pub metadata: serde_json::Value,
79    // Helper to access URI from metadata if it exists, for backward compatibility
80    pub uri: String,
81}
82
83impl VectorStore {
84    /// Create a new vector store for a namespace
85    pub fn new(namespace: &str) -> Result<Self> {
86        // Try to get storage path from environment
87        let storage_path = std::env::var("GRAPH_STORAGE_PATH")
88            .ok()
89            .map(|p| PathBuf::from(p).join(namespace));
90
91        // Get dimensions from env or default
92        let dimensions = std::env::var("VECTOR_DIMENSIONS")
93            .ok()
94            .and_then(|s| s.parse().ok())
95            .unwrap_or(DEFAULT_DIMENSIONS);
96        
97        // Create HNSW index with Euclidian metric
98        let mut index = Hnsw::new(Euclidian);
99        let mut id_to_key = HashMap::new();
100        let mut key_to_id = HashMap::new();
101        let mut key_to_metadata = HashMap::new();
102        let mut embeddings = Vec::new();
103
104        // Try to load persisted vectors
105        if let Some(ref path) = storage_path {
106            let vectors_path = path.join("vectors.json");
107            if vectors_path.exists() {
108                if let Ok(content) = std::fs::read_to_string(&vectors_path) {
109                    // Try new format first
110                    if let Ok(data) = serde_json::from_str::<VectorData>(&content) {
111                        let mut searcher = hnsw::Searcher::default();
112                        for entry in data.entries {
113                            if entry.embedding.len() == dimensions {
114                                let id = index.insert(entry.embedding.clone(), &mut searcher);
115                                id_to_key.insert(id, entry.key.clone());
116                                key_to_id.insert(entry.key.clone(), id);
117                                key_to_metadata.insert(entry.key.clone(), entry.metadata.clone());
118                                embeddings.push(entry);
119                            }
120                        }
121                        eprintln!("Loaded {} vectors from disk (dim={})", embeddings.len(), dimensions);
122                    } else {
123                        // Fallback: Try loading old format (VectorEntry with 'uri' instead of 'key')
124                        #[derive(Serialize, Deserialize)]
125                        struct OldVectorData { entries: Vec<OldVectorEntry> }
126                        #[derive(Serialize, Deserialize)]
127                        struct OldVectorEntry { uri: String, embedding: Vec<f32> }
128
129                        if let Ok(old_data) = serde_json::from_str::<OldVectorData>(&content) {
130                             let mut searcher = hnsw::Searcher::default();
131                             for old in old_data.entries {
132                                 if old.embedding.len() == dimensions {
133                                     let id = index.insert(old.embedding.clone(), &mut searcher);
134                                     id_to_key.insert(id, old.uri.clone());
135                                     key_to_id.insert(old.uri.clone(), id);
136                                     let metadata = serde_json::json!({ "uri": old.uri });
137                                     key_to_metadata.insert(old.uri.clone(), metadata.clone());
138                                     embeddings.push(VectorEntry {
139                                         key: old.uri.clone(),
140                                         embedding: old.embedding,
141                                         metadata,
142                                     });
143                                 }
144                             }
145                             eprintln!("Loaded {} legacy vectors from disk (dim={})", embeddings.len(), dimensions);
146                        }
147                    }
148                }
149            }
150        }
151
152        // Get API token from environment (optional)
153        let api_token = std::env::var("HUGGINGFACE_API_TOKEN").ok();
154
155        // Configured client with timeout
156        let client = Client::builder()
157            .timeout(std::time::Duration::from_secs(30))
158            .build()
159            .unwrap_or_else(|_| Client::new());
160
161        Ok(Self {
162            index: Arc::new(RwLock::new(index)),
163            id_to_key: Arc::new(RwLock::new(id_to_key)),
164            key_to_id: Arc::new(RwLock::new(key_to_id)),
165            key_to_metadata: Arc::new(RwLock::new(key_to_metadata)),
166            storage_path,
167            client,
168            api_token,
169            model: DEFAULT_MODEL.to_string(),
170            dimensions,
171            embeddings: Arc::new(RwLock::new(embeddings)),
172        })
173    }
174
175    /// Save vectors to disk
176    fn save_vectors(&self) -> Result<()> {
177        if let Some(ref path) = self.storage_path {
178            std::fs::create_dir_all(path)?;
179            let data = VectorData {
180                entries: self.embeddings.read().unwrap().clone(),
181            };
182            let content = serde_json::to_string(&data)?;
183            std::fs::write(path.join("vectors.json"), content)?;
184        }
185        Ok(())
186    }
187
188    /// Generate embedding for a text using HuggingFace Inference API
189    /// (Mocked if MOCK_EMBEDDINGS is set)
190    pub async fn embed(&self, text: &str) -> Result<Vec<f32>> {
191        if std::env::var("MOCK_EMBEDDINGS").is_ok() {
192            // Return random embedding for testing
193            use rand::Rng;
194            let mut rng = rand::rng();
195            let vec: Vec<f32> = (0..self.dimensions).map(|_| rng.random()).collect();
196            return Ok(vec);
197        }
198
199        let embeddings = self.embed_batch(vec![text.to_string()]).await?;
200        Ok(embeddings[0].clone())
201    }
202
203    /// Generate embeddings for multiple texts using HuggingFace Inference API
204    pub async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
205        if texts.is_empty() {
206            return Ok(Vec::new());
207        }
208
209        if std::env::var("MOCK_EMBEDDINGS").is_ok() {
210             use rand::Rng;
211             let mut rng = rand::rng();
212             let mut results = Vec::new();
213             for _ in 0..texts.len() {
214                 let vec: Vec<f32> = (0..self.dimensions).map(|_| rng.random()).collect();
215                 results.push(vec);
216             }
217             return Ok(results);
218        }
219
220        let url = format!(
221            "{}/{}/pipeline/feature-extraction",
222            HUGGINGFACE_API_URL, self.model
223        );
224
225        // HuggingFace accepts array of strings for inputs
226        let mut request = self.client.post(&url).json(&serde_json::json!({
227            "inputs": texts,
228        }));
229
230        // Add auth token if available
231        if let Some(ref token) = self.api_token {
232            request = request.header("Authorization", format!("Bearer {}", token));
233        }
234
235        let response = request.send().await?;
236
237        if !response.status().is_success() {
238            let error_text = response.text().await?;
239            anyhow::bail!("HuggingFace API error: {}", error_text);
240        }
241
242        let response_json: serde_json::Value = response.json().await?;
243        let mut results = Vec::new();
244
245        if let Some(arr) = response_json.as_array() {
246            for item in arr {
247                let vec: Vec<f32> = serde_json::from_value(item.clone())
248                    .map_err(|e| anyhow::anyhow!("Failed to parse embedding: {}", e))?;
249
250                if vec.len() != self.dimensions {
251                    anyhow::bail!("Expected {} dimensions, got {}", self.dimensions, vec.len());
252                }
253                results.push(vec);
254            }
255        } else {
256            // Handle case where we sent 1 text and got single flat array [0.1, ...]
257            if texts.len() == 1 {
258                if let Ok(vec) = serde_json::from_value::<Vec<f32>>(response_json) {
259                    if vec.len() == self.dimensions {
260                        results.push(vec);
261                    }
262                }
263            }
264        }
265
266        if results.len() != texts.len() {
267            anyhow::bail!("Expected {} embeddings, got {}", texts.len(), results.len());
268        }
269
270        Ok(results)
271    }
272
273    /// Add a key with its text content to the index
274    pub async fn add(&self, key: &str, content: &str, metadata: serde_json::Value) -> Result<usize> {
275        let results = self.add_batch(vec![(key.to_string(), content.to_string(), metadata)]).await?;
276        Ok(results[0])
277    }
278
279    /// Add multiple keys with their text content to the index
280    pub async fn add_batch(&self, items: Vec<(String, String, serde_json::Value)>) -> Result<Vec<usize>> {
281        // Filter out existing keys
282        let mut new_items = Vec::new();
283        let mut result_ids = vec![0; items.len()];
284        let mut new_indices = Vec::new(); // Map index in new_items to index in items
285
286        {
287            let key_map = self.key_to_id.read().unwrap();
288            for (i, (key, content, _)) in items.iter().enumerate() {
289                if let Some(&id) = key_map.get(key) {
290                    result_ids[i] = id;
291                } else {
292                    new_items.push(content.clone());
293                    new_indices.push(i);
294                }
295            }
296        }
297
298        if new_items.is_empty() {
299            return Ok(result_ids);
300        }
301
302        // Generate embeddings via HuggingFace API
303        let embeddings = self.embed_batch(new_items).await?;
304
305        let mut ids_to_add = Vec::new();
306        let mut searcher = hnsw::Searcher::default();
307
308        // Add to HNSW index and maps
309        {
310            let mut index = self.index.write().unwrap();
311            let mut key_map = self.key_to_id.write().unwrap();
312            let mut id_map = self.id_to_key.write().unwrap();
313            let mut metadata_map = self.key_to_metadata.write().unwrap();
314            let mut embs = self.embeddings.write().unwrap();
315
316            for (i, embedding) in embeddings.into_iter().enumerate() {
317                let original_idx = new_indices[i];
318                let (key, _, metadata) = &items[original_idx];
319
320                // Double check if inserted in race condition
321                if let Some(&id) = key_map.get(key) {
322                    result_ids[original_idx] = id;
323                    continue;
324                }
325
326                let id = index.insert(embedding.clone(), &mut searcher);
327                key_map.insert(key.clone(), id);
328                id_map.insert(id, key.clone());
329                metadata_map.insert(key.clone(), metadata.clone());
330
331                embs.push(VectorEntry {
332                    key: key.clone(),
333                    embedding: embedding,
334                    metadata: metadata.clone(),
335                });
336
337                result_ids[original_idx] = id;
338                ids_to_add.push(id);
339            }
340        }
341
342        if !ids_to_add.is_empty() {
343            let _ = self.save_vectors(); // Best effort persistence
344        }
345
346        Ok(result_ids)
347    }
348
349    /// Search for similar vectors
350    pub async fn search(&self, query: &str, k: usize) -> Result<Vec<SearchResult>> {
351        // Generate query embedding via HuggingFace API
352        let query_embedding = self.embed(query).await?;
353
354        // Search HNSW index
355        let mut searcher = hnsw::Searcher::default();
356        let mut neighbors = vec![
357            space::Neighbor {
358                index: 0,
359                distance: u32::MAX
360            };
361            k
362        ];
363
364        let found_neighbors = {
365            let index = self.index.read().unwrap();
366            index.nearest(&query_embedding, 50, &mut searcher, &mut neighbors)
367        };
368
369        // Convert to results
370        let id_map = self.id_to_key.read().unwrap();
371        let metadata_map = self.key_to_metadata.read().unwrap();
372
373        let results: Vec<SearchResult> = found_neighbors
374            .iter()
375            .filter_map(|neighbor| {
376                id_map.get(&neighbor.index).map(|key| {
377                    let score_f32 = (neighbor.distance as f32) / 1_000_000.0;
378
379                    let metadata = metadata_map.get(key).cloned().unwrap_or(serde_json::Value::Null);
380
381                    let uri = metadata.get("uri").and_then(|v| v.as_str()).unwrap_or(key).to_string();
382
383                    SearchResult {
384                        key: key.clone(),
385                        score: 1.0 / (1.0 + score_f32),
386                        metadata,
387                        uri,
388                    }
389                })
390            })
391            .collect();
392
393        Ok(results)
394    }
395
396    pub fn get_key(&self, id: usize) -> Option<String> {
397        self.id_to_key.read().unwrap().get(&id).cloned()
398    }
399
400    pub fn get_id(&self, key: &str) -> Option<usize> {
401        self.key_to_id.read().unwrap().get(key).copied()
402    }
403
404    pub fn len(&self) -> usize {
405        self.key_to_id.read().unwrap().len()
406    }
407
408    pub fn is_empty(&self) -> bool {
409        self.len() == 0
410    }
411
412    /// Compaction: rebuild index from stored embeddings, removing stale entries
413    pub fn compact(&self) -> Result<usize> {
414        let embeddings = self.embeddings.read().unwrap();
415        let current_keys: std::collections::HashSet<_> = self.key_to_id.read().unwrap().keys().cloned().collect();
416        
417        if current_keys.is_empty() && !embeddings.is_empty() {
418            return Ok(0);
419        }
420
421        // Filter to only current Keys
422        let active_entries: Vec<_> = embeddings
423            .iter()
424            .filter(|e| current_keys.contains(&e.key))
425            .cloned()
426            .collect();
427
428        let removed = embeddings.len() - active_entries.len();
429
430        if removed == 0 {
431            return Ok(0);
432        }
433
434        // Rebuild index
435        let mut new_index = hnsw::Hnsw::new(Euclidian);
436        let mut new_id_to_key = std::collections::HashMap::new();
437        let mut new_key_to_id = std::collections::HashMap::new();
438        let mut new_key_to_metadata = std::collections::HashMap::new();
439        let mut searcher = hnsw::Searcher::default();
440
441        for entry in &active_entries {
442            if entry.embedding.len() == self.dimensions {
443                let id = new_index.insert(entry.embedding.clone(), &mut searcher);
444                new_id_to_key.insert(id, entry.key.clone());
445                new_key_to_id.insert(entry.key.clone(), id);
446                new_key_to_metadata.insert(entry.key.clone(), entry.metadata.clone());
447            }
448        }
449
450        // Swap in new index
451        *self.index.write().unwrap() = new_index;
452        *self.id_to_key.write().unwrap() = new_id_to_key;
453        *self.key_to_id.write().unwrap() = new_key_to_id;
454        *self.key_to_metadata.write().unwrap() = new_key_to_metadata;
455
456        // Update embeddings (drop takes write lock)
457        drop(embeddings);
458        *self.embeddings.write().unwrap() = active_entries;
459
460        let _ = self.save_vectors();
461
462        Ok(removed)
463    }
464
465    /// Remove a Key from the vector store
466    pub fn remove(&self, key: &str) -> bool {
467        let mut key_map = self.key_to_id.write().unwrap();
468        let mut id_map = self.id_to_key.write().unwrap();
469        let mut metadata_map = self.key_to_metadata.write().unwrap();
470
471        if let Some(id) = key_map.remove(key) {
472            id_map.remove(&id);
473            metadata_map.remove(key);
474            // Note: actual index entry remains until compaction
475            true
476        } else {
477            false
478        }
479    }
480
481    /// Get storage stats
482    pub fn stats(&self) -> (usize, usize, usize) {
483        let embeddings_count = self.embeddings.read().unwrap().len();
484        let active_count = self.key_to_id.read().unwrap().len();
485        let stale_count = embeddings_count.saturating_sub(active_count);
486        (active_count, stale_count, embeddings_count)
487    }
488}