Skip to main content

semantic_memory/
hnsw.rs

1//! HNSW approximate nearest-neighbor index wrapper.
2//!
3//! SQLite remains the source of truth. The on-disk HNSW files are a recoverable
4//! acceleration sidecar that can be rebuilt from SQLite whenever needed.
5
6use crate::db;
7use crate::error::MemoryError;
8use hnsw_rs::prelude::*;
9use rusqlite::params;
10use std::collections::{HashMap, HashSet};
11use std::fs::File;
12use std::io::{Read, Seek};
13use std::path::Path;
14use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
15use std::sync::{Arc, RwLock};
16
/// Magic value that opens the `.hnsw.data` sidecar header and every per-vector
/// entry inside it (read little-endian by `load_vectors_from_sidecar`).
const HNSW_DATA_MAGIC: u32 = 0xa67f_0000;
18
/// Configuration for the HNSW index.
#[derive(Debug, Clone)]
pub struct HnswConfig {
    /// Per-node connection count (HNSW `M`), passed to `Hnsw::new` when the
    /// graph is created.
    pub m: usize,
    /// Candidate-list size used while building the graph (HNSW
    /// `efConstruction`), passed to `Hnsw::new`.
    pub ef_construction: usize,
    /// Candidate-list size passed to the graph search by `HnswIndex::search`.
    pub ef_search: usize,
    /// Required length of every inserted and query vector; a loaded data
    /// sidecar must also record this dimension.
    pub dimensions: usize,
    /// Element-count sizing hint passed to `Hnsw::new`.
    pub max_elements: usize,
    /// Tombstone ratio (deleted / total graph points) above which
    /// `HnswIndex::needs_compaction` returns `true`.
    pub compaction_threshold: f32,
    /// Optional flush interval in seconds. Not read by the index itself;
    /// presumably consumed by the caller that drives `HnswIndex::should_flush`
    /// (NOTE(review): confirm).
    pub flush_interval_secs: Option<u64>,
}
30
31impl Default for HnswConfig {
32    fn default() -> Self {
33        Self {
34            m: 16,
35            ef_construction: 200,
36            ef_search: 50,
37            dimensions: 768,
38            max_elements: 100_000,
39            compaction_threshold: 0.3,
40            flush_interval_secs: None,
41        }
42    }
43}
44
/// A single hit from HNSW search.
#[derive(Debug, Clone)]
pub struct HnswHit {
    /// Sidecar key of the matched item, normally `"domain:identifier"`
    /// (see [`HnswHit::parse_key`]).
    pub key: String,
    /// Cosine distance between the query and this item, as reported by the
    /// graph search; smaller is closer.
    pub distance: f32,
}
51
52impl HnswHit {
53    pub fn similarity(&self) -> f32 {
54        (1.0 - self.distance).max(0.0)
55    }
56
57    /// Split the sidecar key into `(domain, identifier)`.
58    pub fn parse_key(&self) -> Result<(&str, &str), MemoryError> {
59        self.key
60            .split_once(':')
61            .ok_or_else(|| MemoryError::InvalidKey(self.key.clone()))
62    }
63}
64
65struct HnswIndexInner {
66    graph: Hnsw<'static, f32, DistCosine>,
67    // CONVENTION EXCEPTION: O(1) lookup required for HNSW index
68    key_to_id: RwLock<HashMap<String, usize>>,
69    // CONVENTION EXCEPTION: O(1) lookup required for HNSW index
70    id_to_key: RwLock<HashMap<usize, String>>,
71    next_id: AtomicUsize,
72    deleted_ids: RwLock<HashSet<usize>>,
73    keymap_dirty: AtomicBool,
74    last_flush_epoch: AtomicU64,
75    config: HnswConfig,
76}
77
/// Seconds since the Unix epoch according to the system clock.
///
/// Returns `0` if the clock reports a time before the epoch.
fn current_epoch_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
84
85#[derive(Clone)]
86pub struct HnswIndex {
87    inner: Arc<HnswIndexInner>,
88}
89
90impl HnswIndex {
91    pub fn new(config: HnswConfig) -> Result<Self, MemoryError> {
92        let graph: Hnsw<'static, f32, DistCosine> = Hnsw::new(
93            config.m,
94            config.max_elements,
95            16,
96            config.ef_construction,
97            DistCosine {},
98        );
99
100        Ok(Self {
101            inner: Arc::new(HnswIndexInner {
102                graph,
103                // CONVENTION EXCEPTION: O(1) lookup required for HNSW index
104                key_to_id: RwLock::new(HashMap::new()),
105                // CONVENTION EXCEPTION: O(1) lookup required for HNSW index
106                id_to_key: RwLock::new(HashMap::new()),
107                next_id: AtomicUsize::new(0),
108                deleted_ids: RwLock::new(HashSet::new()),
109                keymap_dirty: AtomicBool::new(false),
110                last_flush_epoch: AtomicU64::new(current_epoch_secs()),
111                config,
112            }),
113        })
114    }
115
116    /// Load a previously flushed HNSW sidecar by replaying the dumped vectors.
117    ///
118    /// This avoids relying on `hnsw_rs`'s borrowing reload API and keeps the safety
119    /// boundary purely in safe Rust. Node IDs are preserved so the SQLite keymap can
120    /// be loaded afterward.
121    pub fn load(dir: &Path, basename: &str, config: HnswConfig) -> Result<Self, MemoryError> {
122        let data_path = dir.join(format!("{}.hnsw.data", basename));
123        let graph_path = dir.join(format!("{}.hnsw.graph", basename));
124        if !data_path.exists() || !graph_path.exists() {
125            return Err(MemoryError::HnswError(format!(
126                "missing HNSW sidecar files under {}",
127                dir.display()
128            )));
129        }
130
131        let index = Self::new(config)?;
132        validate_graph_sidecar(&graph_path)?;
133        let max_id = load_vectors_from_sidecar(&index, &data_path)?;
134        index
135            .inner
136            .next_id
137            .store(max_id.saturating_add(1), Ordering::SeqCst);
138        Ok(index)
139    }
140
141    pub fn save(&self, dir: &Path, basename: &str) -> Result<(), MemoryError> {
142        self.inner
143            .graph
144            .file_dump(dir, basename)
145            .map_err(|e| MemoryError::HnswError(format!("failed to save HNSW index: {}", e)))?;
146        Ok(())
147    }
148
149    pub fn insert(&self, key: String, vector: &[f32]) -> Result<(), MemoryError> {
150        let id = self.inner.next_id.fetch_add(1, Ordering::SeqCst);
151        self.insert_with_id(Some(key), id, vector)
152    }
153
154    pub fn delete(&self, key: &str) -> Result<(), MemoryError> {
155        let mut key_to_id = self
156            .inner
157            .key_to_id
158            .write()
159            .unwrap_or_else(|e| e.into_inner());
160        let mut id_to_key = self
161            .inner
162            .id_to_key
163            .write()
164            .unwrap_or_else(|e| e.into_inner());
165
166        if let Some(id) = key_to_id.remove(key) {
167            id_to_key.remove(&id);
168            self.inner
169                .deleted_ids
170                .write()
171                .unwrap_or_else(|e| e.into_inner())
172                .insert(id);
173            self.inner.keymap_dirty.store(true, Ordering::Release);
174        }
175        Ok(())
176    }
177
178    pub fn update(&self, key: String, vector: &[f32]) -> Result<(), MemoryError> {
179        self.delete(&key)?;
180        self.insert(key, vector)
181    }
182
183    pub fn search(&self, query: &[f32], top_k: usize) -> Result<Vec<HnswHit>, MemoryError> {
184        validate_dimensions(query, self.inner.config.dimensions)?;
185
186        if self.is_empty() || top_k == 0 {
187            return Ok(Vec::new());
188        }
189
190        let deleted_snapshot = self
191            .inner
192            .deleted_ids
193            .read()
194            .unwrap_or_else(|e| e.into_inner())
195            .clone();
196        let total_points = self.inner.graph.get_nb_point();
197        let fetch_count = top_k
198            .saturating_add(deleted_snapshot.len())
199            .min(total_points);
200
201        let neighbors = self
202            .inner
203            .graph
204            .search(query, fetch_count, self.inner.config.ef_search);
205
206        let id_to_key = self
207            .inner
208            .id_to_key
209            .read()
210            .unwrap_or_else(|e| e.into_inner());
211
212        let mut hits: Vec<HnswHit> = neighbors
213            .into_iter()
214            .filter(|neighbor| !deleted_snapshot.contains(&neighbor.d_id))
215            .filter_map(|neighbor| {
216                id_to_key.get(&neighbor.d_id).map(|key| HnswHit {
217                    key: key.clone(),
218                    distance: neighbor.distance,
219                })
220            })
221            .take(top_k)
222            .collect();
223
224        hits.sort_by(|a, b| {
225            a.distance.partial_cmp(&b.distance).unwrap_or_else(|| {
226                // LIB-020: NaN distances sort to the end rather than comparing as equal
227                if a.distance.is_nan() {
228                    std::cmp::Ordering::Greater
229                } else {
230                    std::cmp::Ordering::Less
231                }
232            })
233        });
234        Ok(hits)
235    }
236
237    pub fn len(&self) -> usize {
238        let total = self.inner.graph.get_nb_point();
239        let deleted = self
240            .inner
241            .deleted_ids
242            .read()
243            .unwrap_or_else(|e| e.into_inner())
244            .len();
245        total.saturating_sub(deleted)
246    }
247
248    pub fn is_empty(&self) -> bool {
249        self.len() == 0
250    }
251
252    pub fn deleted_ratio(&self) -> f32 {
253        let total = self.inner.graph.get_nb_point();
254        if total == 0 {
255            return 0.0;
256        }
257        let deleted = self
258            .inner
259            .deleted_ids
260            .read()
261            .unwrap_or_else(|e| e.into_inner())
262            .len();
263        deleted as f32 / total as f32
264    }
265
266    pub fn needs_compaction(&self) -> bool {
267        self.deleted_ratio() > self.inner.config.compaction_threshold
268    }
269
270    pub fn config(&self) -> &HnswConfig {
271        &self.inner.config
272    }
273
274    pub fn is_keymap_dirty(&self) -> bool {
275        self.inner.keymap_dirty.load(Ordering::Acquire)
276    }
277
278    pub fn should_flush(&self, interval_secs: u64) -> bool {
279        let last = self.inner.last_flush_epoch.load(Ordering::Relaxed);
280        current_epoch_secs().saturating_sub(last) >= interval_secs
281    }
282
283    pub fn update_last_flush_epoch(&self) {
284        self.inner
285            .last_flush_epoch
286            .store(current_epoch_secs(), Ordering::Relaxed);
287    }
288
289    pub fn flush_keymap(&self, conn: &rusqlite::Connection) -> Result<(), MemoryError> {
290        if !self.is_keymap_dirty() {
291            return Ok(());
292        }
293
294        let key_to_id = self
295            .inner
296            .key_to_id
297            .read()
298            .unwrap_or_else(|e| e.into_inner());
299        let deleted = self
300            .inner
301            .deleted_ids
302            .read()
303            .unwrap_or_else(|e| e.into_inner());
304        let next_id = self.inner.next_id.load(Ordering::SeqCst);
305
306        db::with_transaction(conn, |tx| {
307            tx.execute("DELETE FROM hnsw_keymap", [])?;
308            let mut insert_stmt = tx.prepare(
309                "INSERT INTO hnsw_keymap (node_id, item_key, deleted) VALUES (?1, ?2, ?3)",
310            )?;
311
312            for (key, id) in key_to_id.iter() {
313                insert_stmt.execute(params![*id as i64, key, 0])?;
314            }
315            for id in deleted.iter() {
316                insert_stmt.execute(params![*id as i64, format!("_deleted:{}", id), 1])?;
317            }
318            drop(insert_stmt);
319
320            tx.execute(
321                "INSERT INTO hnsw_metadata (key, value) VALUES ('next_id', ?1)
322                 ON CONFLICT(key) DO UPDATE SET value = excluded.value",
323                params![next_id.to_string()],
324            )?;
325            Ok(())
326        })?;
327
328        self.inner.keymap_dirty.store(false, Ordering::Release);
329        Ok(())
330    }
331
332    pub fn load_keymap(&self, conn: &rusqlite::Connection) -> Result<(), MemoryError> {
333        let table_exists: bool = conn
334            .query_row(
335                "SELECT COUNT(*) > 0 FROM sqlite_master WHERE type='table' AND name='hnsw_keymap'",
336                [],
337                |row| row.get(0),
338            )
339            .unwrap_or(false);
340        if !table_exists {
341            tracing::warn!("hnsw_keymap table missing; HNSW keys will remain empty until rebuild");
342            return Ok(());
343        }
344
345        // CONVENTION EXCEPTION: O(1) lookup required for HNSW index
346        let mut key_to_id = HashMap::new();
347        // CONVENTION EXCEPTION: O(1) lookup required for HNSW index
348        let mut id_to_key = HashMap::new();
349        let mut deleted_ids = HashSet::new();
350
351        let mut stmt = conn.prepare("SELECT node_id, item_key, deleted FROM hnsw_keymap")?;
352        let rows = stmt.query_map([], |row| {
353            Ok((
354                row.get::<_, i64>(0)? as usize,
355                row.get::<_, String>(1)?,
356                row.get::<_, bool>(2)?,
357            ))
358        })?;
359
360        for row in rows {
361            let (node_id, item_key, deleted) = row?;
362            if node_id >= self.inner.next_id.load(Ordering::SeqCst) {
363                return Err(MemoryError::HnswError(format!(
364                    "hnsw_keymap node_id {node_id} is outside loaded HNSW sidecar bounds"
365                )));
366            }
367            if deleted {
368                deleted_ids.insert(node_id);
369            } else {
370                key_to_id.insert(item_key.clone(), node_id);
371                id_to_key.insert(node_id, item_key);
372            }
373        }
374
375        let next_id = conn
376            .query_row(
377                "SELECT value FROM hnsw_metadata WHERE key = 'next_id'",
378                [],
379                |row| row.get::<_, String>(0),
380            )
381            .ok()
382            .and_then(|value| value.parse::<usize>().ok())
383            .unwrap_or_else(|| self.inner.graph.get_nb_point());
384
385        *self
386            .inner
387            .key_to_id
388            .write()
389            .unwrap_or_else(|e| e.into_inner()) = key_to_id;
390        *self
391            .inner
392            .id_to_key
393            .write()
394            .unwrap_or_else(|e| e.into_inner()) = id_to_key;
395        *self
396            .inner
397            .deleted_ids
398            .write()
399            .unwrap_or_else(|e| e.into_inner()) = deleted_ids;
400        self.inner.next_id.store(next_id, Ordering::SeqCst);
401        self.inner.keymap_dirty.store(false, Ordering::Release);
402
403        Ok(())
404    }
405
406    fn insert_with_id(
407        &self,
408        key: Option<String>,
409        id: usize,
410        vector: &[f32],
411    ) -> Result<(), MemoryError> {
412        validate_dimensions(vector, self.inner.config.dimensions)?;
413
414        if let Some(key) = key {
415            self.inner.graph.insert((vector, id));
416
417            let mut key_to_id = self
418                .inner
419                .key_to_id
420                .write()
421                .unwrap_or_else(|e| e.into_inner());
422            let mut id_to_key = self
423                .inner
424                .id_to_key
425                .write()
426                .unwrap_or_else(|e| e.into_inner());
427
428            if let Some(old_id) = key_to_id.insert(key.clone(), id) {
429                id_to_key.remove(&old_id);
430                self.inner
431                    .deleted_ids
432                    .write()
433                    .unwrap_or_else(|e| e.into_inner())
434                    .insert(old_id);
435            }
436            id_to_key.insert(id, key);
437            self.inner.keymap_dirty.store(true, Ordering::Release);
438        } else {
439            self.inner.graph.insert((vector, id));
440        }
441        Ok(())
442    }
443}
444
445fn validate_dimensions(vector: &[f32], expected: usize) -> Result<(), MemoryError> {
446    if vector.len() != expected {
447        return Err(MemoryError::HnswError(format!(
448            "expected {} dimensions, got {}",
449            expected,
450            vector.len()
451        )));
452    }
453    // LIB-LOW-002: Reject NaN/infinity embeddings
454    if vector.iter().any(|v| !v.is_finite()) {
455        return Err(MemoryError::HnswError(
456            "embedding contains NaN or infinity values".into(),
457        ));
458    }
459    Ok(())
460}
461
462fn validate_graph_sidecar(graph_path: &Path) -> Result<(), MemoryError> {
463    let mut file = File::open(graph_path).map_err(|e| {
464        MemoryError::HnswError(format!("failed to open {}: {}", graph_path.display(), e))
465    })?;
466    let len = file.seek(std::io::SeekFrom::End(0)).map_err(|e| {
467        MemoryError::HnswError(format!("failed to inspect {}: {}", graph_path.display(), e))
468    })?;
469    if len == 0 {
470        return Err(MemoryError::HnswError(format!(
471            "empty HNSW graph sidecar: {}",
472            graph_path.display()
473        )));
474    }
475    Ok(())
476}
477
478fn load_vectors_from_sidecar(index: &HnswIndex, data_path: &Path) -> Result<usize, MemoryError> {
479    let mut file = File::open(data_path).map_err(|e| {
480        MemoryError::HnswError(format!("failed to open {}: {}", data_path.display(), e))
481    })?;
482
483    let mut u32_buf = [0u8; 4];
484    file.read_exact(&mut u32_buf).map_err(|e| {
485        MemoryError::HnswError(format!("failed to read HNSW sidecar header: {}", e))
486    })?;
487    if u32::from_le_bytes(u32_buf) != HNSW_DATA_MAGIC {
488        return Err(MemoryError::HnswError(
489            "invalid HNSW data file magic header".to_string(),
490        ));
491    }
492
493    let mut usize_buf = [0u8; std::mem::size_of::<usize>()];
494    file.read_exact(&mut usize_buf).map_err(|e| {
495        MemoryError::HnswError(format!("failed to read HNSW sidecar dimensions: {}", e))
496    })?;
497    let dims = usize::from_le_bytes(usize_buf);
498    if dims != index.inner.config.dimensions {
499        return Err(MemoryError::HnswError(format!(
500            "HNSW sidecar dimensions {} do not match configured {}",
501            dims, index.inner.config.dimensions
502        )));
503    }
504
505    let mut max_id = 0usize;
506
507    loop {
508        match file.read_exact(&mut u32_buf) {
509            Ok(()) => {}
510            Err(err) if err.kind() == std::io::ErrorKind::UnexpectedEof => break,
511            Err(err) => {
512                return Err(MemoryError::HnswError(format!(
513                    "failed while reading HNSW sidecar entry header: {}",
514                    err
515                )))
516            }
517        }
518
519        if u32::from_le_bytes(u32_buf) != HNSW_DATA_MAGIC {
520            return Err(MemoryError::HnswError(
521                "invalid per-vector HNSW data magic".to_string(),
522            ));
523        }
524
525        let mut u64_buf = [0u8; 8];
526        file.read_exact(&mut u64_buf).map_err(|e| {
527            MemoryError::HnswError(format!("failed to read HNSW sidecar node id: {}", e))
528        })?;
529        let id = u64::from_le_bytes(u64_buf) as usize;
530
531        file.read_exact(&mut u64_buf).map_err(|e| {
532            MemoryError::HnswError(format!("failed to read HNSW sidecar vector length: {}", e))
533        })?;
534        let byte_len = u64::from_le_bytes(u64_buf) as usize;
535        let mut raw = vec![0u8; byte_len];
536        file.read_exact(&mut raw).map_err(|e| {
537            MemoryError::HnswError(format!("failed to read HNSW sidecar payload: {}", e))
538        })?;
539
540        let vector = db::bytes_to_embedding(&raw)?;
541        index.insert_with_id(None, id, &vector)?;
542        max_id = max_id.max(id);
543    }
544
545    Ok(max_id)
546}