Skip to main content

synapse_core/
vector_store.rs

1use anyhow::Result;
2use hnsw::Hnsw;
3use rand_pcg::Pcg64;
4use reqwest::Client;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::path::PathBuf;
8use std::sync::{Arc, RwLock};
9
/// Embedding model used unless configured otherwise (MiniLM: small and fast).
const DEFAULT_MODEL: &str = "sentence-transformers/all-MiniLM-L6-v2";
/// Output length of `DEFAULT_MODEL`; the `VECTOR_DIMENSIONS` env var overrides it.
const DEFAULT_DIMENSIONS: usize = 384;
/// Base URL of the HuggingFace inference router; the model id is appended to it.
const HUGGINGFACE_API_URL: &str = "https://router.huggingface.co/hf-inference/models";
13
14/// Euclidean distance metric for HNSW
15#[derive(Default, Clone)]
16pub struct Euclidian;
17
18impl space::Metric<Vec<f32>> for Euclidian {
19    type Unit = u32;
20    fn distance(&self, a: &Vec<f32>, b: &Vec<f32>) -> u32 {
21        let len = a.len().min(b.len());
22        let mut dist_sq = 0.0;
23        for i in 0..len {
24            let diff = a[i] - b[i];
25            dist_sq += diff * diff;
26        }
27        // Use fixed-point arithmetic to avoid bitwise comparison issues
28        (dist_sq.sqrt() * 1_000_000.0) as u32
29    }
30}
31
/// On-disk format of `vectors.json`: the serialized list of stored embeddings.
///
/// Field names are the JSON keys, so renaming them breaks existing files.
#[derive(Serialize, Deserialize, Default)]
struct VectorData {
    /// Stored entries, in the order they were inserted.
    entries: Vec<VectorEntry>,
}
37
/// One stored embedding together with the URI it was computed for.
#[derive(Serialize, Deserialize, Clone)]
struct VectorEntry {
    /// URI the embedding belongs to.
    uri: String,
    /// Embedding vector. When loading from disk, entries whose length differs
    /// from the store's configured dimensions are skipped.
    embedding: Vec<f32>,
}
43
/// Vector store using the HuggingFace Inference API for embeddings.
///
/// Embeddings live in an in-memory HNSW index and, when a storage path is
/// configured, are persisted to `vectors.json`. Each piece of shared state is
/// behind its own `RwLock`.
pub struct VectorStore {
    /// HNSW index for fast approximate nearest-neighbor search. Vectors of
    /// removed URIs stay here until `compact` rebuilds the index.
    index: Arc<RwLock<Hnsw<Euclidian, Vec<f32>, Pcg64, 16, 32>>>,
    /// Index id -> URI, for live (not removed) URIs only.
    id_to_uri: Arc<RwLock<HashMap<usize, String>>>,
    /// URI -> index id, for live (not removed) URIs only.
    uri_to_id: Arc<RwLock<HashMap<String, usize>>>,
    /// Directory holding `vectors.json`; `None` disables persistence.
    storage_path: Option<PathBuf>,
    /// HTTP client for the HuggingFace API.
    client: Client,
    /// HuggingFace API token (optional, for rate limits); sent as a bearer
    /// token when present.
    api_token: Option<String>,
    /// Model id appended to the inference API URL.
    model: String,
    /// Expected embedding length; API responses of any other length are rejected.
    dimensions: usize,
    /// Embeddings inserted into the store, kept for persistence and compaction.
    /// May still contain entries of removed URIs until `compact` runs.
    embeddings: Arc<RwLock<Vec<VectorEntry>>>,
}
65
/// A single hit returned by `VectorStore::search`.
#[derive(Debug, Serialize, Deserialize)]
pub struct SearchResult {
    /// URI of the matching entry.
    pub uri: String,
    /// Similarity score `1 / (1 + distance)`: in `(0, 1]`, higher means closer.
    pub score: f32,
    /// Text for the hit. The store does not keep the original content, so
    /// `search` currently fills this with the URI.
    pub content: String,
}
72
73impl VectorStore {
74    /// Create a new vector store for a namespace
75    pub fn new(namespace: &str) -> Result<Self> {
76        // Try to get storage path from environment
77        let storage_path = std::env::var("GRAPH_STORAGE_PATH")
78            .ok()
79            .map(|p| PathBuf::from(p).join(namespace));
80
81        // Get dimensions from env or default
82        let dimensions = std::env::var("VECTOR_DIMENSIONS")
83            .ok()
84            .and_then(|s| s.parse().ok())
85            .unwrap_or(DEFAULT_DIMENSIONS);
86        
87        // Create HNSW index with Euclidian metric
88        let mut index = Hnsw::new(Euclidian);
89        let mut id_to_uri = HashMap::new();
90        let mut uri_to_id = HashMap::new();
91        let mut embeddings = Vec::new();
92
93        // Try to load persisted vectors
94        if let Some(ref path) = storage_path {
95            let vectors_path = path.join("vectors.json");
96            if vectors_path.exists() {
97                if let Ok(content) = std::fs::read_to_string(&vectors_path) {
98                    if let Ok(data) = serde_json::from_str::<VectorData>(&content) {
99                        let mut searcher = hnsw::Searcher::default();
100                        for entry in data.entries {
101                            if entry.embedding.len() == dimensions {
102                                let id = index.insert(entry.embedding.clone(), &mut searcher);
103                                id_to_uri.insert(id, entry.uri.clone());
104                                uri_to_id.insert(entry.uri.clone(), id);
105                                embeddings.push(entry);
106                            }
107                        }
108                        eprintln!("Loaded {} vectors from disk (dim={})", embeddings.len(), dimensions);
109                    }
110                }
111            }
112        }
113
114        // Get API token from environment (optional)
115        let api_token = std::env::var("HUGGINGFACE_API_TOKEN").ok();
116
117        // Configured client with timeout
118        let client = Client::builder()
119            .timeout(std::time::Duration::from_secs(30))
120            .build()
121            .unwrap_or_else(|_| Client::new());
122
123        Ok(Self {
124            index: Arc::new(RwLock::new(index)),
125            id_to_uri: Arc::new(RwLock::new(id_to_uri)),
126            uri_to_id: Arc::new(RwLock::new(uri_to_id)),
127            storage_path,
128            client,
129            api_token,
130            model: DEFAULT_MODEL.to_string(),
131            dimensions,
132            embeddings: Arc::new(RwLock::new(embeddings)),
133        })
134    }
135
136    /// Save vectors to disk
137    fn save_vectors(&self) -> Result<()> {
138        if let Some(ref path) = self.storage_path {
139            std::fs::create_dir_all(path)?;
140            let data = VectorData {
141                entries: self.embeddings.read().unwrap().clone(),
142            };
143            let content = serde_json::to_string(&data)?;
144            std::fs::write(path.join("vectors.json"), content)?;
145        }
146        Ok(())
147    }
148
149    /// Generate embedding for a text using HuggingFace Inference API
150    pub async fn embed(&self, text: &str) -> Result<Vec<f32>> {
151        let embeddings = self.embed_batch(vec![text.to_string()]).await?;
152        Ok(embeddings[0].clone())
153    }
154
155    /// Generate embeddings for multiple texts using HuggingFace Inference API
156    pub async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
157        if texts.is_empty() {
158            return Ok(Vec::new());
159        }
160
161        let url = format!(
162            "{}/{}/pipeline/feature-extraction",
163            HUGGINGFACE_API_URL, self.model
164        );
165
166        // HuggingFace accepts array of strings for inputs
167        let mut request = self.client.post(&url).json(&serde_json::json!({
168            "inputs": texts,
169        }));
170
171        // Add auth token if available
172        if let Some(ref token) = self.api_token {
173            request = request.header("Authorization", format!("Bearer {}", token));
174        }
175
176        let response = request.send().await?;
177
178        if !response.status().is_success() {
179            let error_text = response.text().await?;
180            anyhow::bail!("HuggingFace API error: {}", error_text);
181        }
182
183        let response_json: serde_json::Value = response.json().await?;
184        let mut results = Vec::new();
185
186        if let Some(arr) = response_json.as_array() {
187            for item in arr {
188                let vec: Vec<f32> = serde_json::from_value(item.clone())
189                    .map_err(|e| anyhow::anyhow!("Failed to parse embedding: {}", e))?;
190
191                if vec.len() != self.dimensions {
192                    anyhow::bail!("Expected {} dimensions, got {}", self.dimensions, vec.len());
193                }
194                results.push(vec);
195            }
196        } else {
197            // Handle case where we sent 1 text and got single flat array [0.1, ...]
198            if texts.len() == 1 {
199                if let Ok(vec) = serde_json::from_value::<Vec<f32>>(response_json) {
200                    if vec.len() == self.dimensions {
201                        results.push(vec);
202                    }
203                }
204            }
205        }
206
207        if results.len() != texts.len() {
208            anyhow::bail!("Expected {} embeddings, got {}", texts.len(), results.len());
209        }
210
211        Ok(results)
212    }
213
214    /// Add a URI with its text content to the index
215    pub async fn add(&self, uri: &str, content: &str) -> Result<usize> {
216        let results = self.add_batch(vec![(uri.to_string(), content.to_string())]).await?;
217        Ok(results[0])
218    }
219
220    /// Add multiple URIs with their text content to the index
221    pub async fn add_batch(&self, items: Vec<(String, String)>) -> Result<Vec<usize>> {
222        // Filter out existing URIs
223        let mut new_items = Vec::new();
224        let mut result_ids = vec![0; items.len()];
225        let mut new_indices = Vec::new(); // Map index in new_items to index in items
226
227        {
228            let uri_map = self.uri_to_id.read().unwrap();
229            for (i, (uri, content)) in items.iter().enumerate() {
230                if let Some(&id) = uri_map.get(uri) {
231                    result_ids[i] = id;
232                } else {
233                    new_items.push(content.clone());
234                    new_indices.push(i);
235                }
236            }
237        }
238
239        if new_items.is_empty() {
240            return Ok(result_ids);
241        }
242
243        // Generate embeddings via HuggingFace API
244        let embeddings = self.embed_batch(new_items).await?;
245
246        let mut ids_to_add = Vec::new();
247        let mut searcher = hnsw::Searcher::default();
248
249        // Add to HNSW index and maps
250        {
251            let mut index = self.index.write().unwrap();
252            let mut uri_map = self.uri_to_id.write().unwrap();
253            let mut id_map = self.id_to_uri.write().unwrap();
254            let mut embs = self.embeddings.write().unwrap();
255
256            for (i, embedding) in embeddings.into_iter().enumerate() {
257                let original_idx = new_indices[i];
258                let uri = &items[original_idx].0;
259
260                // Double check if inserted in race condition
261                if let Some(&id) = uri_map.get(uri) {
262                    result_ids[original_idx] = id;
263                    continue;
264                }
265
266                let id = index.insert(embedding.clone(), &mut searcher);
267                uri_map.insert(uri.clone(), id);
268                id_map.insert(id, uri.clone());
269
270                embs.push(VectorEntry {
271                    uri: uri.clone(),
272                    embedding: embedding,
273                });
274
275                result_ids[original_idx] = id;
276                ids_to_add.push(id);
277            }
278        }
279
280        if !ids_to_add.is_empty() {
281            let _ = self.save_vectors(); // Best effort persistence
282        }
283
284        Ok(result_ids)
285    }
286
287    /// Search for similar vectors
288    pub async fn search(&self, query: &str, k: usize) -> Result<Vec<SearchResult>> {
289        // Generate query embedding via HuggingFace API
290        let query_embedding = self.embed(query).await?;
291
292        // Search HNSW index
293        let mut searcher = hnsw::Searcher::default();
294        let mut neighbors = vec![
295            space::Neighbor {
296                index: 0,
297                distance: u32::MAX
298            };
299            k
300        ];
301
302        let found_neighbors = {
303            let index = self.index.read().unwrap();
304            index.nearest(&query_embedding, 50, &mut searcher, &mut neighbors)
305        };
306
307        // Convert to results
308        let id_map = self.id_to_uri.read().unwrap();
309        let results: Vec<SearchResult> = found_neighbors
310            .iter()
311            .filter_map(|neighbor| {
312                id_map.get(&neighbor.index).map(|uri| {
313                    let score_f32 = (neighbor.distance as f32) / 1_000_000.0;
314                    SearchResult {
315                        uri: uri.clone(),
316                        score: 1.0 / (1.0 + score_f32),
317                        content: uri.clone(),
318                    }
319                })
320            })
321            .collect();
322
323        Ok(results)
324    }
325
326    pub fn get_uri(&self, id: usize) -> Option<String> {
327        self.id_to_uri.read().unwrap().get(&id).cloned()
328    }
329
330    pub fn get_id(&self, uri: &str) -> Option<usize> {
331        self.uri_to_id.read().unwrap().get(uri).copied()
332    }
333
334    pub fn len(&self) -> usize {
335        self.uri_to_id.read().unwrap().len()
336    }
337
338    pub fn is_empty(&self) -> bool {
339        self.len() == 0
340    }
341
342    /// Compaction: rebuild index from stored embeddings, removing stale entries
343    pub fn compact(&self) -> Result<usize> {
344        let embeddings = self.embeddings.read().unwrap();
345        let current_uris: std::collections::HashSet<_> = self.uri_to_id.read().unwrap().keys().cloned().collect();
346        
347        // Safeguard: If uri_to_id is empty, avoid compaction unless embeddings is also empty
348        // to prevent accidental deletion if mappings were lost.
349        if current_uris.is_empty() && !embeddings.is_empty() {
350            // We return 0 and skip compaction to be safe
351            return Ok(0);
352        }
353
354        // Filter to only current URIs
355        let active_entries: Vec<_> = embeddings
356            .iter()
357            .filter(|e| current_uris.contains(&e.uri))
358            .cloned()
359            .collect();
360
361        let removed = embeddings.len() - active_entries.len();
362
363        if removed == 0 {
364            return Ok(0);
365        }
366
367        // Rebuild index
368        let mut new_index = hnsw::Hnsw::new(Euclidian);
369        let mut new_id_to_uri = std::collections::HashMap::new();
370        let mut new_uri_to_id = std::collections::HashMap::new();
371        let mut searcher = hnsw::Searcher::default();
372
373        for entry in &active_entries {
374            if entry.embedding.len() == self.dimensions {
375                let id = new_index.insert(entry.embedding.clone(), &mut searcher);
376                new_id_to_uri.insert(id, entry.uri.clone());
377                new_uri_to_id.insert(entry.uri.clone(), id);
378            }
379        }
380
381        // Swap in new index
382        *self.index.write().unwrap() = new_index;
383        *self.id_to_uri.write().unwrap() = new_id_to_uri;
384        *self.uri_to_id.write().unwrap() = new_uri_to_id;
385
386        // Update embeddings (drop takes write lock)
387        drop(embeddings);
388        *self.embeddings.write().unwrap() = active_entries;
389
390        let _ = self.save_vectors();
391
392        Ok(removed)
393    }
394
395    /// Remove a URI from the vector store
396    pub fn remove(&self, uri: &str) -> bool {
397        let mut uri_map = self.uri_to_id.write().unwrap();
398        let mut id_map = self.id_to_uri.write().unwrap();
399
400        if let Some(id) = uri_map.remove(uri) {
401            id_map.remove(&id);
402            // Note: actual index entry remains until compaction
403            true
404        } else {
405            false
406        }
407    }
408
409    /// Get storage stats
410    pub fn stats(&self) -> (usize, usize, usize) {
411        let embeddings_count = self.embeddings.read().unwrap().len();
412        let active_count = self.uri_to_id.read().unwrap().len();
413        let stale_count = embeddings_count.saturating_sub(active_count);
414        (active_count, stale_count, embeddings_count)
415    }
416}