Skip to main content

synapse_core/
vector_store.rs

1use anyhow::Result;
2use hnsw::Hnsw;
3use reqwest::Client;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::sync::{Arc, RwLock};
7use std::path::PathBuf;
8use rand_pcg::Pcg64;
9
/// Base URL of the Hugging Face inference router; the model id and pipeline path are appended
/// to it for each embedding request.
const HUGGINGFACE_API_URL: &str = "https://router.huggingface.co/hf-inference/models";

/// Embedding model used by every `VectorStore` (384-dimensional output, chosen for speed).
const DEFAULT_MODEL: &str = "sentence-transformers/all-MiniLM-L6-v2";
12
13/// Euclidean distance metric for HNSW
14#[derive(Default, Clone)]
15pub struct Euclidian;
16
17impl space::Metric<[f32; 384]> for Euclidian {
18    type Unit = u64;
19    fn distance(&self, a: &[f32; 384], b: &[f32; 384]) -> u64 {
20        let mut dist_sq = 0.0;
21        for i in 0..384 {
22            let diff = a[i] - b[i];
23            dist_sq += diff * diff;
24        }
25        // Floating point to bits for ordered comparison as per space v0.17 recommendations
26        dist_sq.sqrt().to_bits() as u64
27    }
28}
29
30/// Persisted vector data
31#[derive(Serialize, Deserialize, Default)]
32struct VectorData {
33    entries: Vec<VectorEntry>,
34}
35
36#[derive(Serialize, Deserialize, Clone)]
37struct VectorEntry {
38    uri: String,
39    embedding: Vec<f32>,
40}
41
42/// Vector store using HuggingFace Inference API for embeddings
43pub struct VectorStore {
44    /// HNSW index for fast approximate nearest neighbor search
45    index: Arc<RwLock<Hnsw<Euclidian, [f32; 384], Pcg64, 16, 32>>>,
46    /// Mapping from node ID to URI
47    id_to_uri: Arc<RwLock<HashMap<usize, String>>>,
48    /// Mapping from URI to node ID
49    uri_to_id: Arc<RwLock<HashMap<String, usize>>>,
50    /// Storage path for persistence
51    storage_path: Option<PathBuf>,
52    /// HTTP client for HuggingFace API
53    client: Client,
54    /// HuggingFace API token (optional, for rate limits)
55    api_token: Option<String>,
56    /// Model name
57    model: String,
58    /// Stored embeddings for persistence
59    embeddings: Arc<RwLock<Vec<VectorEntry>>>,
60}
61
62#[derive(Debug, Serialize, Deserialize)]
63pub struct SearchResult {
64    pub uri: String,
65    pub score: f32,
66    pub content: String,
67}
68
69#[derive(Serialize)]
70struct EmbeddingRequest {
71    inputs: String,
72}
73
74impl VectorStore {
75    /// Create a new vector store for a namespace
76    pub fn new(namespace: &str) -> Result<Self> {
77        // Try to get storage path from environment
78        let storage_path = std::env::var("GRAPH_STORAGE_PATH")
79            .ok()
80            .map(|p| PathBuf::from(p).join(namespace));
81        
82        // Create HNSW index with Euclidian metric
83        let mut index = Hnsw::new(Euclidian);
84        let mut id_to_uri = HashMap::new();
85        let mut uri_to_id = HashMap::new();
86        let mut embeddings = Vec::new();
87
88        // Try to load persisted vectors
89        if let Some(ref path) = storage_path {
90            let vectors_path = path.join("vectors.json");
91            if vectors_path.exists() {
92                if let Ok(content) = std::fs::read_to_string(&vectors_path) {
93                    if let Ok(data) = serde_json::from_str::<VectorData>(&content) {
94                        let mut searcher = hnsw::Searcher::default();
95                        for entry in data.entries {
96                            if entry.embedding.len() == 384 {
97                                let mut emb = [0.0f32; 384];
98                                emb.copy_from_slice(&entry.embedding);
99                                let id = index.insert(emb, &mut searcher);
100                                id_to_uri.insert(id, entry.uri.clone());
101                                uri_to_id.insert(entry.uri.clone(), id);
102                                embeddings.push(entry);
103                            }
104                        }
105                        eprintln!("Loaded {} vectors from disk", embeddings.len());
106                    }
107                }
108            }
109        }
110
111        // Get API token from environment (optional)
112        let api_token = std::env::var("HUGGINGFACE_API_TOKEN").ok();
113
114        Ok(Self {
115            index: Arc::new(RwLock::new(index)),
116            id_to_uri: Arc::new(RwLock::new(id_to_uri)),
117            uri_to_id: Arc::new(RwLock::new(uri_to_id)),
118            storage_path,
119            client: Client::new(),
120            api_token,
121            model: DEFAULT_MODEL.to_string(),
122            embeddings: Arc::new(RwLock::new(embeddings)),
123        })
124    }
125
126    /// Save vectors to disk
127    fn save_vectors(&self) -> Result<()> {
128        if let Some(ref path) = self.storage_path {
129            std::fs::create_dir_all(path)?;
130            let data = VectorData {
131                entries: self.embeddings.read().unwrap().clone(),
132            };
133            let content = serde_json::to_string(&data)?;
134            std::fs::write(path.join("vectors.json"), content)?;
135        }
136        Ok(())
137    }
138
139    /// Generate embedding for a text using HuggingFace Inference API
140    pub async fn embed(&self, text: &str) -> Result<[f32; 384]> {
141        let url = format!("{}/{}/pipeline/feature-extraction", HUGGINGFACE_API_URL, self.model);
142        
143        let mut request = self.client
144            .post(&url)
145            .json(&EmbeddingRequest {
146                inputs: text.to_string(),
147            });
148
149        // Add auth token if available
150        if let Some(ref token) = self.api_token {
151            request = request.header("Authorization", format!("Bearer {}", token));
152        }
153
154        let response = request.send().await?;
155        
156        if !response.status().is_success() {
157            let error_text = response.text().await?;
158            anyhow::bail!("HuggingFace API error: {}", error_text);
159        }
160
161        // Response is a Vec<f32> directly
162        let embedding_vec: Vec<f32> = response.json().await?;
163        
164        if embedding_vec.len() != 384 {
165            anyhow::bail!("Expected 384 dimensions, got {}", embedding_vec.len());
166        }
167
168        let mut embedding = [0.0f32; 384];
169        embedding.copy_from_slice(&embedding_vec[0..384]);
170        
171        Ok(embedding)
172    }
173
174    /// Add a URI with its text content to the index
175    pub async fn add(&self, uri: &str, content: &str) -> Result<usize> {
176        // Check if URI already exists
177        {
178            let uri_map = self.uri_to_id.read().unwrap();
179            if let Some(&id) = uri_map.get(uri) {
180                return Ok(id);
181            }
182        }
183
184        // Generate embedding via HuggingFace API
185        let embedding = self.embed(content).await?;
186        
187        // Add to HNSW index
188        let mut searcher = hnsw::Searcher::default();
189        let id = {
190            let mut index = self.index.write().unwrap();
191            index.insert(embedding, &mut searcher)
192        };
193
194        // Update mappings
195        {
196            let mut uri_map = self.uri_to_id.write().unwrap();
197            let mut id_map = self.id_to_uri.write().unwrap();
198            uri_map.insert(uri.to_string(), id);
199            id_map.insert(id, uri.to_string());
200        }
201
202        // Persist embedding for recovery
203        {
204            let mut embs = self.embeddings.write().unwrap();
205            embs.push(VectorEntry {
206                uri: uri.to_string(),
207                embedding: embedding.to_vec(),
208            });
209        }
210        let _ = self.save_vectors(); // Best effort persistence
211
212        Ok(id)
213    }
214
215    /// Search for similar vectors
216    pub async fn search(&self, query: &str, k: usize) -> Result<Vec<SearchResult>> {
217        // Generate query embedding via HuggingFace API
218        let query_embedding = self.embed(query).await?;
219
220        // Search HNSW index
221        let mut searcher = hnsw::Searcher::default();
222        let mut neighbors = vec![space::Neighbor { index: 0, distance: !0 }; k];
223        
224        let found_neighbors = {
225            let index = self.index.read().unwrap();
226            index.nearest(&query_embedding, 50, &mut searcher, &mut neighbors)
227        };
228
229        // Convert to results
230        let id_map = self.id_to_uri.read().unwrap();
231        let results: Vec<SearchResult> = found_neighbors
232            .iter()
233            .filter_map(|neighbor| {
234                id_map.get(&neighbor.index).map(|uri| {
235                    // Convert back from bits to f32
236                    let score_f32 = f32::from_bits(neighbor.distance as u32);
237                    SearchResult {
238                        uri: uri.clone(),
239                        score: 1.0 / (1.0 + score_f32), 
240                        content: uri.clone(), 
241                    }
242                })
243            })
244            .collect();
245
246        Ok(results)
247    }
248
249    pub fn get_uri(&self, id: usize) -> Option<String> {
250        self.id_to_uri.read().unwrap().get(&id).cloned()
251    }
252
253    pub fn get_id(&self, uri: &str) -> Option<usize> {
254        self.uri_to_id.read().unwrap().get(uri).copied()
255    }
256
257    pub fn len(&self) -> usize {
258        self.uri_to_id.read().unwrap().len()
259    }
260
261    pub fn is_empty(&self) -> bool {
262        self.len() == 0
263    }
264
265    /// Compaction: rebuild index from stored embeddings, removing stale entries
266    pub fn compact(&self) -> Result<usize> {
267        let embeddings = self.embeddings.read().unwrap();
268        let current_uris: std::collections::HashSet<_> = self.uri_to_id.read().unwrap().keys().cloned().collect();
269        
270        // Filter to only current URIs
271        let active_entries: Vec<_> = embeddings.iter()
272            .filter(|e| current_uris.contains(&e.uri))
273            .cloned()
274            .collect();
275
276        let removed = embeddings.len() - active_entries.len();
277        
278        if removed == 0 {
279            return Ok(0);
280        }
281
282        // Rebuild index
283        let mut new_index = hnsw::Hnsw::new(Euclidian);
284        let mut new_id_to_uri = std::collections::HashMap::new();
285        let mut new_uri_to_id = std::collections::HashMap::new();
286        let mut searcher = hnsw::Searcher::default();
287
288        for entry in &active_entries {
289            if entry.embedding.len() == 384 {
290                let mut emb = [0.0f32; 384];
291                emb.copy_from_slice(&entry.embedding);
292                let id = new_index.insert(emb, &mut searcher);
293                new_id_to_uri.insert(id, entry.uri.clone());
294                new_uri_to_id.insert(entry.uri.clone(), id);
295            }
296        }
297
298        // Swap in new index
299        *self.index.write().unwrap() = new_index;
300        *self.id_to_uri.write().unwrap() = new_id_to_uri;
301        *self.uri_to_id.write().unwrap() = new_uri_to_id;
302        
303        // Update embeddings (drop takes write lock)
304        drop(embeddings);
305        *self.embeddings.write().unwrap() = active_entries;
306        
307        let _ = self.save_vectors();
308        
309        Ok(removed)
310    }
311
312    /// Remove a URI from the vector store
313    pub fn remove(&self, uri: &str) -> bool {
314        let mut uri_map = self.uri_to_id.write().unwrap();
315        let mut id_map = self.id_to_uri.write().unwrap();
316
317        if let Some(id) = uri_map.remove(uri) {
318            id_map.remove(&id);
319            // Note: actual index entry remains until compaction
320            true
321        } else {
322            false
323        }
324    }
325
326    /// Get storage stats
327    pub fn stats(&self) -> (usize, usize, usize) {
328        let embeddings_count = self.embeddings.read().unwrap().len();
329        let active_count = self.uri_to_id.read().unwrap().len();
330        let stale_count = embeddings_count.saturating_sub(active_count);
331        (active_count, stale_count, embeddings_count)
332    }
333}