1use crate::db;
7use crate::error::MemoryError;
8use hnsw_rs::prelude::*;
9use rusqlite::params;
10use std::collections::{HashMap, HashSet};
11use std::fs::File;
12use std::io::{Read, Seek};
13use std::path::Path;
14use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
15use std::sync::{Arc, RwLock};
16
/// Marker expected in the HNSW data sidecar.
///
/// `load_vectors_from_sidecar` compares this value (decoded as a little-endian
/// `u32`) against the first four bytes of the file and against the four bytes
/// that open every per-vector record; any mismatch aborts the load.
const HNSW_DATA_MAGIC: u32 = 0xa67f0000;
18
/// Tunable parameters for an [`HnswIndex`].
#[derive(Debug, Clone)]
pub struct HnswConfig {
    /// First argument handed to `Hnsw::new` when the graph is built
    /// (presumably the per-node connection limit; confirm against hnsw_rs).
    pub m: usize,
    /// Construction-time search width, forwarded to `Hnsw::new`.
    pub ef_construction: usize,
    /// Query-time search width, forwarded to the graph on every search.
    pub ef_search: usize,
    /// Exact length every inserted vector and every query must have.
    pub dimensions: usize,
    /// Capacity value forwarded to `Hnsw::new`.
    pub max_elements: usize,
    /// Compaction is reported as needed once the deleted ratio strictly
    /// exceeds this value.
    pub compaction_threshold: f32,
    /// Optional flush period in seconds. Not read by the index itself in this
    /// file; presumably consumed by the caller that schedules flushes.
    pub flush_interval_secs: Option<u64>,
}

impl Default for HnswConfig {
    /// Returns the stock configuration: 768-dimensional vectors, `m = 16`,
    /// `ef_construction = 200`, `ef_search = 50`, room for 100 000 elements,
    /// a 0.3 compaction threshold and no periodic flush interval.
    fn default() -> Self {
        HnswConfig {
            dimensions: 768,
            max_elements: 100_000,
            m: 16,
            ef_construction: 200,
            ef_search: 50,
            compaction_threshold: 0.3,
            flush_interval_secs: None,
        }
    }
}
44
/// One result returned by [`HnswIndex::search`].
#[derive(Debug, Clone)]
pub struct HnswHit {
    /// Key that was supplied when the matching vector was inserted.
    /// `parse_key` expects it to contain a `:` separator.
    pub key: String,
    /// Distance reported by the graph for this neighbour (the graph is built
    /// with `DistCosine`); smaller means closer.
    pub distance: f32,
}
51
52impl HnswHit {
53 pub fn similarity(&self) -> f32 {
54 (1.0 - self.distance).max(0.0)
55 }
56
57 pub fn parse_key(&self) -> Result<(&str, &str), MemoryError> {
59 self.key
60 .split_once(':')
61 .ok_or_else(|| MemoryError::InvalidKey(self.key.clone()))
62 }
63}
64
/// Shared state behind every [`HnswIndex`] handle.
///
/// Methods in this file that hold several of the locks at once acquire them
/// in the order `key_to_id`, then `id_to_key`, then `deleted_ids`.
struct HnswIndexInner {
    /// The underlying hnsw_rs graph using cosine distance over `f32` vectors.
    graph: Hnsw<'static, f32, DistCosine>,
    /// Live keys mapped to the graph id of their current vector.
    key_to_id: RwLock<HashMap<String, usize>>,
    /// Reverse of `key_to_id`; search uses it to turn graph ids into keys.
    id_to_key: RwLock<HashMap<usize, String>>,
    /// Next graph id to hand out; `insert` advances it with `fetch_add`.
    next_id: AtomicUsize,
    /// Ids whose key was deleted or replaced. The points are not removed
    /// from the graph by this file; search filters them out instead.
    deleted_ids: RwLock<HashSet<usize>>,
    /// Set when the key maps change, cleared by `flush_keymap`/`load_keymap`.
    keymap_dirty: AtomicBool,
    /// Unix time in seconds recorded at construction and by
    /// `update_last_flush_epoch`.
    last_flush_epoch: AtomicU64,
    /// Configuration the index was created with.
    config: HnswConfig,
}
77
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// If the system clock reports a time before the epoch the result is `0`.
fn current_epoch_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
84
/// Handle to an HNSW vector index.
///
/// Cloning is cheap: the state lives behind an `Arc`, so every clone refers
/// to the same graph and key maps.
#[derive(Clone)]
pub struct HnswIndex {
    /// Shared graph, key maps and bookkeeping.
    inner: Arc<HnswIndexInner>,
}
89
90impl HnswIndex {
91 pub fn new(config: HnswConfig) -> Result<Self, MemoryError> {
92 let graph: Hnsw<'static, f32, DistCosine> = Hnsw::new(
93 config.m,
94 config.max_elements,
95 16,
96 config.ef_construction,
97 DistCosine {},
98 );
99
100 Ok(Self {
101 inner: Arc::new(HnswIndexInner {
102 graph,
103 key_to_id: RwLock::new(HashMap::new()),
105 id_to_key: RwLock::new(HashMap::new()),
107 next_id: AtomicUsize::new(0),
108 deleted_ids: RwLock::new(HashSet::new()),
109 keymap_dirty: AtomicBool::new(false),
110 last_flush_epoch: AtomicU64::new(current_epoch_secs()),
111 config,
112 }),
113 })
114 }
115
116 pub fn load(dir: &Path, basename: &str, config: HnswConfig) -> Result<Self, MemoryError> {
122 let data_path = dir.join(format!("{}.hnsw.data", basename));
123 let graph_path = dir.join(format!("{}.hnsw.graph", basename));
124 if !data_path.exists() || !graph_path.exists() {
125 return Err(MemoryError::HnswError(format!(
126 "missing HNSW sidecar files under {}",
127 dir.display()
128 )));
129 }
130
131 let index = Self::new(config)?;
132 validate_graph_sidecar(&graph_path)?;
133 let max_id = load_vectors_from_sidecar(&index, &data_path)?;
134 index
135 .inner
136 .next_id
137 .store(max_id.saturating_add(1), Ordering::SeqCst);
138 Ok(index)
139 }
140
141 pub fn save(&self, dir: &Path, basename: &str) -> Result<(), MemoryError> {
142 self.inner
143 .graph
144 .file_dump(dir, basename)
145 .map_err(|e| MemoryError::HnswError(format!("failed to save HNSW index: {}", e)))?;
146 Ok(())
147 }
148
149 pub fn insert(&self, key: String, vector: &[f32]) -> Result<(), MemoryError> {
150 let id = self.inner.next_id.fetch_add(1, Ordering::SeqCst);
151 self.insert_with_id(Some(key), id, vector)
152 }
153
154 pub fn delete(&self, key: &str) -> Result<(), MemoryError> {
155 let mut key_to_id = self
156 .inner
157 .key_to_id
158 .write()
159 .unwrap_or_else(|e| e.into_inner());
160 let mut id_to_key = self
161 .inner
162 .id_to_key
163 .write()
164 .unwrap_or_else(|e| e.into_inner());
165
166 if let Some(id) = key_to_id.remove(key) {
167 id_to_key.remove(&id);
168 self.inner
169 .deleted_ids
170 .write()
171 .unwrap_or_else(|e| e.into_inner())
172 .insert(id);
173 self.inner.keymap_dirty.store(true, Ordering::Release);
174 }
175 Ok(())
176 }
177
178 pub fn update(&self, key: String, vector: &[f32]) -> Result<(), MemoryError> {
179 self.delete(&key)?;
180 self.insert(key, vector)
181 }
182
183 pub fn search(&self, query: &[f32], top_k: usize) -> Result<Vec<HnswHit>, MemoryError> {
184 validate_dimensions(query, self.inner.config.dimensions)?;
185
186 if self.is_empty() || top_k == 0 {
187 return Ok(Vec::new());
188 }
189
190 let deleted_snapshot = self
191 .inner
192 .deleted_ids
193 .read()
194 .unwrap_or_else(|e| e.into_inner())
195 .clone();
196 let total_points = self.inner.graph.get_nb_point();
197 let fetch_count = top_k
198 .saturating_add(deleted_snapshot.len())
199 .min(total_points);
200
201 let neighbors = self
202 .inner
203 .graph
204 .search(query, fetch_count, self.inner.config.ef_search);
205
206 let id_to_key = self
207 .inner
208 .id_to_key
209 .read()
210 .unwrap_or_else(|e| e.into_inner());
211
212 let mut hits: Vec<HnswHit> = neighbors
213 .into_iter()
214 .filter(|neighbor| !deleted_snapshot.contains(&neighbor.d_id))
215 .filter_map(|neighbor| {
216 id_to_key.get(&neighbor.d_id).map(|key| HnswHit {
217 key: key.clone(),
218 distance: neighbor.distance,
219 })
220 })
221 .take(top_k)
222 .collect();
223
224 hits.sort_by(|a, b| {
225 a.distance.partial_cmp(&b.distance).unwrap_or_else(|| {
226 if a.distance.is_nan() {
228 std::cmp::Ordering::Greater
229 } else {
230 std::cmp::Ordering::Less
231 }
232 })
233 });
234 Ok(hits)
235 }
236
237 pub fn len(&self) -> usize {
238 let total = self.inner.graph.get_nb_point();
239 let deleted = self
240 .inner
241 .deleted_ids
242 .read()
243 .unwrap_or_else(|e| e.into_inner())
244 .len();
245 total.saturating_sub(deleted)
246 }
247
248 pub fn is_empty(&self) -> bool {
249 self.len() == 0
250 }
251
252 pub fn deleted_ratio(&self) -> f32 {
253 let total = self.inner.graph.get_nb_point();
254 if total == 0 {
255 return 0.0;
256 }
257 let deleted = self
258 .inner
259 .deleted_ids
260 .read()
261 .unwrap_or_else(|e| e.into_inner())
262 .len();
263 deleted as f32 / total as f32
264 }
265
266 pub fn needs_compaction(&self) -> bool {
267 self.deleted_ratio() > self.inner.config.compaction_threshold
268 }
269
270 pub fn config(&self) -> &HnswConfig {
271 &self.inner.config
272 }
273
274 pub fn is_keymap_dirty(&self) -> bool {
275 self.inner.keymap_dirty.load(Ordering::Acquire)
276 }
277
278 pub fn should_flush(&self, interval_secs: u64) -> bool {
279 let last = self.inner.last_flush_epoch.load(Ordering::Relaxed);
280 current_epoch_secs().saturating_sub(last) >= interval_secs
281 }
282
283 pub fn update_last_flush_epoch(&self) {
284 self.inner
285 .last_flush_epoch
286 .store(current_epoch_secs(), Ordering::Relaxed);
287 }
288
289 pub fn flush_keymap(&self, conn: &rusqlite::Connection) -> Result<(), MemoryError> {
290 if !self.is_keymap_dirty() {
291 return Ok(());
292 }
293
294 let key_to_id = self
295 .inner
296 .key_to_id
297 .read()
298 .unwrap_or_else(|e| e.into_inner());
299 let deleted = self
300 .inner
301 .deleted_ids
302 .read()
303 .unwrap_or_else(|e| e.into_inner());
304 let next_id = self.inner.next_id.load(Ordering::SeqCst);
305
306 db::with_transaction(conn, |tx| {
307 tx.execute("DELETE FROM hnsw_keymap", [])?;
308 let mut insert_stmt = tx.prepare(
309 "INSERT INTO hnsw_keymap (node_id, item_key, deleted) VALUES (?1, ?2, ?3)",
310 )?;
311
312 for (key, id) in key_to_id.iter() {
313 insert_stmt.execute(params![*id as i64, key, 0])?;
314 }
315 for id in deleted.iter() {
316 insert_stmt.execute(params![*id as i64, format!("_deleted:{}", id), 1])?;
317 }
318 drop(insert_stmt);
319
320 tx.execute(
321 "INSERT INTO hnsw_metadata (key, value) VALUES ('next_id', ?1)
322 ON CONFLICT(key) DO UPDATE SET value = excluded.value",
323 params![next_id.to_string()],
324 )?;
325 Ok(())
326 })?;
327
328 self.inner.keymap_dirty.store(false, Ordering::Release);
329 Ok(())
330 }
331
332 pub fn load_keymap(&self, conn: &rusqlite::Connection) -> Result<(), MemoryError> {
333 let table_exists: bool = conn
334 .query_row(
335 "SELECT COUNT(*) > 0 FROM sqlite_master WHERE type='table' AND name='hnsw_keymap'",
336 [],
337 |row| row.get(0),
338 )
339 .unwrap_or(false);
340 if !table_exists {
341 tracing::warn!("hnsw_keymap table missing; HNSW keys will remain empty until rebuild");
342 return Ok(());
343 }
344
345 let mut key_to_id = HashMap::new();
347 let mut id_to_key = HashMap::new();
349 let mut deleted_ids = HashSet::new();
350
351 let mut stmt = conn.prepare("SELECT node_id, item_key, deleted FROM hnsw_keymap")?;
352 let rows = stmt.query_map([], |row| {
353 Ok((
354 row.get::<_, i64>(0)? as usize,
355 row.get::<_, String>(1)?,
356 row.get::<_, bool>(2)?,
357 ))
358 })?;
359
360 for row in rows {
361 let (node_id, item_key, deleted) = row?;
362 if node_id >= self.inner.next_id.load(Ordering::SeqCst) {
363 return Err(MemoryError::HnswError(format!(
364 "hnsw_keymap node_id {node_id} is outside loaded HNSW sidecar bounds"
365 )));
366 }
367 if deleted {
368 deleted_ids.insert(node_id);
369 } else {
370 key_to_id.insert(item_key.clone(), node_id);
371 id_to_key.insert(node_id, item_key);
372 }
373 }
374
375 let next_id = conn
376 .query_row(
377 "SELECT value FROM hnsw_metadata WHERE key = 'next_id'",
378 [],
379 |row| row.get::<_, String>(0),
380 )
381 .ok()
382 .and_then(|value| value.parse::<usize>().ok())
383 .unwrap_or_else(|| self.inner.graph.get_nb_point());
384
385 *self
386 .inner
387 .key_to_id
388 .write()
389 .unwrap_or_else(|e| e.into_inner()) = key_to_id;
390 *self
391 .inner
392 .id_to_key
393 .write()
394 .unwrap_or_else(|e| e.into_inner()) = id_to_key;
395 *self
396 .inner
397 .deleted_ids
398 .write()
399 .unwrap_or_else(|e| e.into_inner()) = deleted_ids;
400 self.inner.next_id.store(next_id, Ordering::SeqCst);
401 self.inner.keymap_dirty.store(false, Ordering::Release);
402
403 Ok(())
404 }
405
406 fn insert_with_id(
407 &self,
408 key: Option<String>,
409 id: usize,
410 vector: &[f32],
411 ) -> Result<(), MemoryError> {
412 validate_dimensions(vector, self.inner.config.dimensions)?;
413
414 if let Some(key) = key {
415 self.inner.graph.insert((vector, id));
416
417 let mut key_to_id = self
418 .inner
419 .key_to_id
420 .write()
421 .unwrap_or_else(|e| e.into_inner());
422 let mut id_to_key = self
423 .inner
424 .id_to_key
425 .write()
426 .unwrap_or_else(|e| e.into_inner());
427
428 if let Some(old_id) = key_to_id.insert(key.clone(), id) {
429 id_to_key.remove(&old_id);
430 self.inner
431 .deleted_ids
432 .write()
433 .unwrap_or_else(|e| e.into_inner())
434 .insert(old_id);
435 }
436 id_to_key.insert(id, key);
437 self.inner.keymap_dirty.store(true, Ordering::Release);
438 } else {
439 self.inner.graph.insert((vector, id));
440 }
441 Ok(())
442 }
443}
444
445fn validate_dimensions(vector: &[f32], expected: usize) -> Result<(), MemoryError> {
446 if vector.len() != expected {
447 return Err(MemoryError::HnswError(format!(
448 "expected {} dimensions, got {}",
449 expected,
450 vector.len()
451 )));
452 }
453 if vector.iter().any(|v| !v.is_finite()) {
455 return Err(MemoryError::HnswError(
456 "embedding contains NaN or infinity values".into(),
457 ));
458 }
459 Ok(())
460}
461
462fn validate_graph_sidecar(graph_path: &Path) -> Result<(), MemoryError> {
463 let mut file = File::open(graph_path).map_err(|e| {
464 MemoryError::HnswError(format!("failed to open {}: {}", graph_path.display(), e))
465 })?;
466 let len = file.seek(std::io::SeekFrom::End(0)).map_err(|e| {
467 MemoryError::HnswError(format!("failed to inspect {}: {}", graph_path.display(), e))
468 })?;
469 if len == 0 {
470 return Err(MemoryError::HnswError(format!(
471 "empty HNSW graph sidecar: {}",
472 graph_path.display()
473 )));
474 }
475 Ok(())
476}
477
478fn load_vectors_from_sidecar(index: &HnswIndex, data_path: &Path) -> Result<usize, MemoryError> {
479 let mut file = File::open(data_path).map_err(|e| {
480 MemoryError::HnswError(format!("failed to open {}: {}", data_path.display(), e))
481 })?;
482
483 let mut u32_buf = [0u8; 4];
484 file.read_exact(&mut u32_buf).map_err(|e| {
485 MemoryError::HnswError(format!("failed to read HNSW sidecar header: {}", e))
486 })?;
487 if u32::from_le_bytes(u32_buf) != HNSW_DATA_MAGIC {
488 return Err(MemoryError::HnswError(
489 "invalid HNSW data file magic header".to_string(),
490 ));
491 }
492
493 let mut usize_buf = [0u8; std::mem::size_of::<usize>()];
494 file.read_exact(&mut usize_buf).map_err(|e| {
495 MemoryError::HnswError(format!("failed to read HNSW sidecar dimensions: {}", e))
496 })?;
497 let dims = usize::from_le_bytes(usize_buf);
498 if dims != index.inner.config.dimensions {
499 return Err(MemoryError::HnswError(format!(
500 "HNSW sidecar dimensions {} do not match configured {}",
501 dims, index.inner.config.dimensions
502 )));
503 }
504
505 let mut max_id = 0usize;
506
507 loop {
508 match file.read_exact(&mut u32_buf) {
509 Ok(()) => {}
510 Err(err) if err.kind() == std::io::ErrorKind::UnexpectedEof => break,
511 Err(err) => {
512 return Err(MemoryError::HnswError(format!(
513 "failed while reading HNSW sidecar entry header: {}",
514 err
515 )))
516 }
517 }
518
519 if u32::from_le_bytes(u32_buf) != HNSW_DATA_MAGIC {
520 return Err(MemoryError::HnswError(
521 "invalid per-vector HNSW data magic".to_string(),
522 ));
523 }
524
525 let mut u64_buf = [0u8; 8];
526 file.read_exact(&mut u64_buf).map_err(|e| {
527 MemoryError::HnswError(format!("failed to read HNSW sidecar node id: {}", e))
528 })?;
529 let id = u64::from_le_bytes(u64_buf) as usize;
530
531 file.read_exact(&mut u64_buf).map_err(|e| {
532 MemoryError::HnswError(format!("failed to read HNSW sidecar vector length: {}", e))
533 })?;
534 let byte_len = u64::from_le_bytes(u64_buf) as usize;
535 let mut raw = vec![0u8; byte_len];
536 file.read_exact(&mut raw).map_err(|e| {
537 MemoryError::HnswError(format!("failed to read HNSW sidecar payload: {}", e))
538 })?;
539
540 let vector = db::bytes_to_embedding(&raw)?;
541 index.insert_with_id(None, id, &vector)?;
542 max_id = max_id.max(id);
543 }
544
545 Ok(max_id)
546}