// synapse_core/vector_store.rs
use std::collections::{HashMap, HashSet};
use std::path::{Path, PathBuf};
use std::sync::{Arc, RwLock};

use anyhow::Result;
use hnsw::Hnsw;
use rand_pcg::Pcg64;
use reqwest::Client;
use serde::{Deserialize, Serialize};

10const HUGGINGFACE_API_URL: &str = "https://router.huggingface.co/hf-inference/models";
11const DEFAULT_MODEL: &str = "sentence-transformers/all-MiniLM-L6-v2"; #[derive(Default, Clone)]
15pub struct Euclidian;
16
17impl space::Metric<[f32; 384]> for Euclidian {
18 type Unit = u64;
19 fn distance(&self, a: &[f32; 384], b: &[f32; 384]) -> u64 {
20 let mut dist_sq = 0.0;
21 for i in 0..384 {
22 let diff = a[i] - b[i];
23 dist_sq += diff * diff;
24 }
25 dist_sq.sqrt().to_bits() as u64
27 }
28}
29
30#[derive(Serialize, Deserialize, Default)]
32struct VectorData {
33 entries: Vec<VectorEntry>,
34}
35
36#[derive(Serialize, Deserialize, Clone)]
37struct VectorEntry {
38 uri: String,
39 embedding: Vec<f32>,
40}
41
42pub struct VectorStore {
44 index: Arc<RwLock<Hnsw<Euclidian, [f32; 384], Pcg64, 16, 32>>>,
46 id_to_uri: Arc<RwLock<HashMap<usize, String>>>,
48 uri_to_id: Arc<RwLock<HashMap<String, usize>>>,
50 storage_path: Option<PathBuf>,
52 client: Client,
54 api_token: Option<String>,
56 model: String,
58 embeddings: Arc<RwLock<Vec<VectorEntry>>>,
60}
61
62#[derive(Debug, Serialize, Deserialize)]
63pub struct SearchResult {
64 pub uri: String,
65 pub score: f32,
66 pub content: String,
67}
68
69#[derive(Serialize)]
70struct EmbeddingRequest {
71 inputs: String,
72}
73
74impl VectorStore {
75 pub fn new(namespace: &str) -> Result<Self> {
77 let storage_path = std::env::var("GRAPH_STORAGE_PATH")
79 .ok()
80 .map(|p| PathBuf::from(p).join(namespace));
81
82 let mut index = Hnsw::new(Euclidian);
84 let mut id_to_uri = HashMap::new();
85 let mut uri_to_id = HashMap::new();
86 let mut embeddings = Vec::new();
87
88 if let Some(ref path) = storage_path {
90 let vectors_path = path.join("vectors.json");
91 if vectors_path.exists() {
92 if let Ok(content) = std::fs::read_to_string(&vectors_path) {
93 if let Ok(data) = serde_json::from_str::<VectorData>(&content) {
94 let mut searcher = hnsw::Searcher::default();
95 for entry in data.entries {
96 if entry.embedding.len() == 384 {
97 let mut emb = [0.0f32; 384];
98 emb.copy_from_slice(&entry.embedding);
99 let id = index.insert(emb, &mut searcher);
100 id_to_uri.insert(id, entry.uri.clone());
101 uri_to_id.insert(entry.uri.clone(), id);
102 embeddings.push(entry);
103 }
104 }
105 eprintln!("Loaded {} vectors from disk", embeddings.len());
106 }
107 }
108 }
109 }
110
111 let api_token = std::env::var("HUGGINGFACE_API_TOKEN").ok();
113
114 Ok(Self {
115 index: Arc::new(RwLock::new(index)),
116 id_to_uri: Arc::new(RwLock::new(id_to_uri)),
117 uri_to_id: Arc::new(RwLock::new(uri_to_id)),
118 storage_path,
119 client: Client::new(),
120 api_token,
121 model: DEFAULT_MODEL.to_string(),
122 embeddings: Arc::new(RwLock::new(embeddings)),
123 })
124 }
125
126 fn save_vectors(&self) -> Result<()> {
128 if let Some(ref path) = self.storage_path {
129 std::fs::create_dir_all(path)?;
130 let data = VectorData {
131 entries: self.embeddings.read().unwrap().clone(),
132 };
133 let content = serde_json::to_string(&data)?;
134 std::fs::write(path.join("vectors.json"), content)?;
135 }
136 Ok(())
137 }
138
139 pub async fn embed(&self, text: &str) -> Result<[f32; 384]> {
141 let url = format!("{}/{}/pipeline/feature-extraction", HUGGINGFACE_API_URL, self.model);
142
143 let mut request = self.client
144 .post(&url)
145 .json(&EmbeddingRequest {
146 inputs: text.to_string(),
147 });
148
149 if let Some(ref token) = self.api_token {
151 request = request.header("Authorization", format!("Bearer {}", token));
152 }
153
154 let response = request.send().await?;
155
156 if !response.status().is_success() {
157 let error_text = response.text().await?;
158 anyhow::bail!("HuggingFace API error: {}", error_text);
159 }
160
161 let embedding_vec: Vec<f32> = response.json().await?;
163
164 if embedding_vec.len() != 384 {
165 anyhow::bail!("Expected 384 dimensions, got {}", embedding_vec.len());
166 }
167
168 let mut embedding = [0.0f32; 384];
169 embedding.copy_from_slice(&embedding_vec[0..384]);
170
171 Ok(embedding)
172 }
173
174 pub async fn add(&self, uri: &str, content: &str) -> Result<usize> {
176 {
178 let uri_map = self.uri_to_id.read().unwrap();
179 if let Some(&id) = uri_map.get(uri) {
180 return Ok(id);
181 }
182 }
183
184 let embedding = self.embed(content).await?;
186
187 let mut searcher = hnsw::Searcher::default();
189 let id = {
190 let mut index = self.index.write().unwrap();
191 index.insert(embedding, &mut searcher)
192 };
193
194 {
196 let mut uri_map = self.uri_to_id.write().unwrap();
197 let mut id_map = self.id_to_uri.write().unwrap();
198 uri_map.insert(uri.to_string(), id);
199 id_map.insert(id, uri.to_string());
200 }
201
202 {
204 let mut embs = self.embeddings.write().unwrap();
205 embs.push(VectorEntry {
206 uri: uri.to_string(),
207 embedding: embedding.to_vec(),
208 });
209 }
210 let _ = self.save_vectors(); Ok(id)
213 }
214
215 pub async fn search(&self, query: &str, k: usize) -> Result<Vec<SearchResult>> {
217 let query_embedding = self.embed(query).await?;
219
220 let mut searcher = hnsw::Searcher::default();
222 let mut neighbors = vec![space::Neighbor { index: 0, distance: !0 }; k];
223
224 let found_neighbors = {
225 let index = self.index.read().unwrap();
226 index.nearest(&query_embedding, 50, &mut searcher, &mut neighbors)
227 };
228
229 let id_map = self.id_to_uri.read().unwrap();
231 let results: Vec<SearchResult> = found_neighbors
232 .iter()
233 .filter_map(|neighbor| {
234 id_map.get(&neighbor.index).map(|uri| {
235 let score_f32 = f32::from_bits(neighbor.distance as u32);
237 SearchResult {
238 uri: uri.clone(),
239 score: 1.0 / (1.0 + score_f32),
240 content: uri.clone(),
241 }
242 })
243 })
244 .collect();
245
246 Ok(results)
247 }
248
249 pub fn get_uri(&self, id: usize) -> Option<String> {
250 self.id_to_uri.read().unwrap().get(&id).cloned()
251 }
252
253 pub fn get_id(&self, uri: &str) -> Option<usize> {
254 self.uri_to_id.read().unwrap().get(uri).copied()
255 }
256
257 pub fn len(&self) -> usize {
258 self.uri_to_id.read().unwrap().len()
259 }
260
261 pub fn is_empty(&self) -> bool {
262 self.len() == 0
263 }
264
265 pub fn compact(&self) -> Result<usize> {
267 let embeddings = self.embeddings.read().unwrap();
268 let current_uris: std::collections::HashSet<_> = self.uri_to_id.read().unwrap().keys().cloned().collect();
269
270 let active_entries: Vec<_> = embeddings.iter()
272 .filter(|e| current_uris.contains(&e.uri))
273 .cloned()
274 .collect();
275
276 let removed = embeddings.len() - active_entries.len();
277
278 if removed == 0 {
279 return Ok(0);
280 }
281
282 let mut new_index = hnsw::Hnsw::new(Euclidian);
284 let mut new_id_to_uri = std::collections::HashMap::new();
285 let mut new_uri_to_id = std::collections::HashMap::new();
286 let mut searcher = hnsw::Searcher::default();
287
288 for entry in &active_entries {
289 if entry.embedding.len() == 384 {
290 let mut emb = [0.0f32; 384];
291 emb.copy_from_slice(&entry.embedding);
292 let id = new_index.insert(emb, &mut searcher);
293 new_id_to_uri.insert(id, entry.uri.clone());
294 new_uri_to_id.insert(entry.uri.clone(), id);
295 }
296 }
297
298 *self.index.write().unwrap() = new_index;
300 *self.id_to_uri.write().unwrap() = new_id_to_uri;
301 *self.uri_to_id.write().unwrap() = new_uri_to_id;
302
303 drop(embeddings);
305 *self.embeddings.write().unwrap() = active_entries;
306
307 let _ = self.save_vectors();
308
309 Ok(removed)
310 }
311
312 pub fn remove(&self, uri: &str) -> bool {
314 let mut uri_map = self.uri_to_id.write().unwrap();
315 let mut id_map = self.id_to_uri.write().unwrap();
316
317 if let Some(id) = uri_map.remove(uri) {
318 id_map.remove(&id);
319 true
321 } else {
322 false
323 }
324 }
325
326 pub fn stats(&self) -> (usize, usize, usize) {
328 let embeddings_count = self.embeddings.read().unwrap().len();
329 let active_count = self.uri_to_id.read().unwrap().len();
330 let stale_count = embeddings_count.saturating_sub(active_count);
331 (active_count, stale_count, embeddings_count)
332 }
333}