Skip to main content

sqlrite/sql/
hnsw.rs

1//! HNSW (Hierarchical Navigable Small World) approximate-nearest-neighbor
2//! index. Pure algorithm; no SQL integration in this module.
3//!
4//! HNSW is the industry-standard ANN algorithm for in-memory vector search:
5//! a multi-layer graph where each node lives at some randomly-assigned max
6//! layer; higher layers are sparser, layer 0 contains every node. Search
7//! starts at the entry point (the node at the current top layer), greedily
8//! descends layer-by-layer, then does a beam search at layer 0.
9//!
10//! ```text
11//!     layer 2:   [A] -- [E]                    sparse
12//!                 |       |
13//!     layer 1:   [A] -- [E] -- [G] -- [J]      mid
14//!                 |  /  |  \   |  \   |
15//!     layer 0:   [A,B,C,D,E,F,G,H,I,J,...]     dense (every node)
16//! ```
17//!
18//! ## What this module is responsible for
19//!
20//! - The graph (per-node, per-layer neighbor lists)
21//! - Layer assignment for new nodes (geometric distribution)
22//! - Insertion: greedy descent + beam search + neighbor pruning
23//! - Query: greedy descent + beam search at layer 0, return top-k
24//!
25//! ## What it is NOT responsible for (yet)
26//!
//! - **Storing vectors.** The algorithm calls a `get_vec(node_id) -> Vec<f32>`
//!   closure to fetch (an owned copy of) the vector for any node it touches. In Phase 7d.2
29//!   that closure will read from the SQL table holding the indexed
30//!   column; in tests it reads from an in-memory `Vec<Vec<f32>>`.
31//! - **Persistence.** The graph lives in `HashMap<i64, Node>` for now.
32//!   Phase 7d.3 wires it into the cell-encoded page format.
33//! - **DELETE / UPDATE.** Pre-existing nodes can't be removed today.
34//!   Soft-delete + lazy rebuild is the planned approach for 7d.2/7d.3.
35//!
36//! ## Parameters (per Phase 7 plan Q2 — fixed defaults)
37//!
38//! - `M = 16`              — max neighbors per node at layers > 0
39//! - `m_max0 = 32` (= 2·M) — max neighbors at layer 0
40//! - `ef_construction = 200` — beam width during INSERT
41//! - `ef_search = 50`      — default beam width during query
42//! - `m_l = 1/ln(M) ≈ 0.36`  — layer-assignment scale
43//!
44//! ## Invariants
45//!
46//! - Every `node.layers` Vec has length `node_max_layer + 1` for that node.
47//! - `node.layers[i]` contains node_ids of neighbors at layer i. Each
48//!   neighbor is itself a node in `nodes`; symmetrical (if A → B at layer i
49//!   then B → A at layer i, modulo pruning).
50//! - `entry_point` is `Some(id)` iff `nodes` is non-empty. The entry node
51//!   has the highest max-layer of any node currently in the graph.
52
53use std::cmp::Ordering;
54use std::collections::{BinaryHeap, HashMap, HashSet};
55
56/// Distance metric used by the HNSW index. Must match what the
57/// surrounding `vec_distance_*` SQL function would compute on the same
58/// pair of vectors — otherwise the index probe and the brute-force
59/// fallback would disagree on which rows are "nearest". See
60/// `src/sql/executor.rs`'s `vec_distance_l2` / `_cosine` / `_dot` for
61/// the canonical implementations.
62#[derive(Debug, Clone, Copy, PartialEq, Eq)]
63pub enum DistanceMetric {
64    L2,
65    Cosine,
66    Dot,
67}
68
69impl DistanceMetric {
70    /// Parses the metric name from the `CREATE INDEX … WITH
71    /// (metric = '<name>')` clause. Case-insensitive. Returns `None`
72    /// for unknown values; the parser surfaces that as a user-visible
73    /// error so a typo doesn't silently fall back to L2.
74    pub fn from_sql_name(name: &str) -> Option<Self> {
75        match name.to_ascii_lowercase().as_str() {
76            "l2" | "euclidean" => Some(DistanceMetric::L2),
77            "cosine" => Some(DistanceMetric::Cosine),
78            "dot" | "inner_product" | "ip" => Some(DistanceMetric::Dot),
79            _ => None,
80        }
81    }
82
83    /// The canonical SQL-surface name for this metric. Used when
84    /// synthesizing CREATE INDEX SQL back into `sqlrite_master`.
85    pub fn sql_name(self) -> &'static str {
86        match self {
87            DistanceMetric::L2 => "l2",
88            DistanceMetric::Cosine => "cosine",
89            DistanceMetric::Dot => "dot",
90        }
91    }
92
93    /// The `vec_distance_*` SQL function whose result this metric
94    /// orders by. The optimizer's HNSW shortcut only fires when the
95    /// query's ORDER BY expression names this exact function.
96    pub fn matching_distance_fn(self) -> &'static str {
97        match self {
98            DistanceMetric::L2 => "vec_distance_l2",
99            DistanceMetric::Cosine => "vec_distance_cosine",
100            DistanceMetric::Dot => "vec_distance_dot",
101        }
102    }
103
104    /// Computes the configured distance between two equal-dimension
105    /// vectors. Returns `f32::INFINITY` for the cosine/zero-magnitude
106    /// edge case; HNSW treats infinity as "worst possible candidate" and
107    /// will prefer any finite alternative, which matches the SQL-level
108    /// behaviour where `vec_distance_cosine` errors but the optimizer's
109    /// fallback path simply skips the offending row.
110    pub fn compute(self, a: &[f32], b: &[f32]) -> f32 {
111        debug_assert_eq!(a.len(), b.len(), "vector dim mismatch in HNSW distance");
112        match self {
113            DistanceMetric::L2 => {
114                let mut sum = 0.0f32;
115                for i in 0..a.len() {
116                    let d = a[i] - b[i];
117                    sum += d * d;
118                }
119                sum.sqrt()
120            }
121            DistanceMetric::Cosine => {
122                let mut dot = 0.0f32;
123                let mut na = 0.0f32;
124                let mut nb = 0.0f32;
125                for i in 0..a.len() {
126                    dot += a[i] * b[i];
127                    na += a[i] * a[i];
128                    nb += b[i] * b[i];
129                }
130                let denom = (na * nb).sqrt();
131                if denom == 0.0 {
132                    f32::INFINITY
133                } else {
134                    1.0 - dot / denom
135                }
136            }
137            DistanceMetric::Dot => {
138                let mut dot = 0.0f32;
139                for i in 0..a.len() {
140                    dot += a[i] * b[i];
141                }
142                -dot
143            }
144        }
145    }
146}
147
148/// Per-node metadata: a list of neighbor IDs for each layer this node
149/// lives in. `layers[0]` is layer 0 (densest); `layers[layers.len() - 1]`
150/// is the highest layer this node reaches.
151#[derive(Debug, Clone, Default)]
152pub struct Node {
153    /// Indexed by layer (0 = dense). `layers[i]` is the neighbor list
154    /// for this node at layer i. Always sorted-by-distance is *not* a
155    /// guaranteed invariant — pruning maintains it after each
156    /// modification, but during insert we may briefly hold an
157    /// unsorted set.
158    pub layers: Vec<Vec<i64>>,
159}
160
161impl Node {
162    /// Maximum layer this node reaches. Equals `layers.len() - 1`.
163    pub fn max_layer(&self) -> usize {
164        self.layers.len() - 1
165    }
166}
167
168/// HNSW algorithm parameters. Phase 7 ships fixed defaults (Q2 in the
169/// plan); this struct is `Clone + Copy` so callers wanting to fork an
170/// experimental tuning can do so without touching the index itself.
171#[derive(Debug, Clone, Copy)]
172pub struct HnswParams {
173    pub m: usize,
174    pub m_max0: usize,
175    pub ef_construction: usize,
176    pub ef_search: usize,
177    pub m_l: f32,
178}
179
180impl Default for HnswParams {
181    fn default() -> Self {
182        let m = 16;
183        Self {
184            m,
185            m_max0: 2 * m,
186            ef_construction: 200,
187            ef_search: 50,
188            m_l: 1.0 / (m as f32).ln(),
189        }
190    }
191}
192
193/// In-memory HNSW graph. See module docs for the model.
194#[derive(Debug, Clone)]
195pub struct HnswIndex {
196    pub params: HnswParams,
197    pub distance: DistanceMetric,
198    /// Node id of the entry point. `None` iff the index is empty.
199    /// At all times this is the id of the node with the highest
200    /// max-layer; if multiple nodes tie for the top layer, the
201    /// most-recently-promoted one wins.
202    pub entry_point: Option<i64>,
203    /// Highest layer currently populated. 0 when the index has at
204    /// most one node, grows as new nodes get assigned higher layers.
205    pub top_layer: usize,
206    /// Node id → its per-layer neighbor lists.
207    pub nodes: HashMap<i64, Node>,
208    /// xorshift64 RNG state for layer assignment. Seeded explicitly via
209    /// `new` so tests can pin a known sequence.
210    rng_state: u64,
211}
212
213impl HnswIndex {
214    /// Builds an empty HNSW index with default parameters and the given
215    /// distance metric + RNG seed. A seed of 0 is mapped to a small
216    /// nonzero constant — xorshift gets stuck at zero.
217    pub fn new(distance: DistanceMetric, seed: u64) -> Self {
218        let seed = if seed == 0 { 0x9E3779B97F4A7C15 } else { seed };
219        Self {
220            params: HnswParams::default(),
221            distance,
222            entry_point: None,
223            top_layer: 0,
224            nodes: HashMap::new(),
225            rng_state: seed,
226        }
227    }
228
229    /// True if no nodes have been inserted yet.
230    pub fn is_empty(&self) -> bool {
231        self.nodes.is_empty()
232    }
233
234    /// Number of nodes currently in the index.
235    pub fn len(&self) -> usize {
236        self.nodes.len()
237    }
238
239    /// Phase 7d.3 — produces (node_id, layers) pairs in ascending node_id
240    /// order, suitable for serializing the graph to disk via the
241    /// `HnswNodeCell` wire format. The graph's metadata
242    /// (entry_point + top_layer) is recoverable from the nodes alone:
243    /// top_layer = max(max_layer); entry_point = any node at top_layer.
244    /// So we don't ship a separate metadata cell.
245    pub fn serialize_nodes(&self) -> Vec<(i64, Vec<Vec<i64>>)> {
246        let mut out: Vec<(i64, Vec<Vec<i64>>)> = self
247            .nodes
248            .iter()
249            .map(|(id, n)| (*id, n.layers.clone()))
250            .collect();
251        out.sort_by_key(|(id, _)| *id);
252        out
253    }
254
255    /// Phase 7d.3 — rebuilds an HnswIndex from a stream of (node_id, layers)
256    /// pairs as produced by `serialize_nodes` and round-tripped through
257    /// `HnswNodeCell` encode/decode. The rebuilt index has the same nodes,
258    /// same neighbor lists, same entry_point + top_layer as the source.
259    /// `seed` is fresh; the deserialized index is never inserted into via
260    /// the algorithmic `insert` path so the seed only matters if a caller
261    /// later calls `insert` after deserializing (then it controls layer
262    /// assignment for the appended node).
263    pub fn from_persisted_nodes<I>(distance: DistanceMetric, seed: u64, nodes: I) -> Self
264    where
265        I: IntoIterator<Item = (i64, Vec<Vec<i64>>)>,
266    {
267        let mut idx = Self::new(distance, seed);
268        let mut top_layer = 0usize;
269        let mut entry_point: Option<i64> = None;
270        for (id, layers) in nodes {
271            let max_layer = layers.len().saturating_sub(1);
272            if max_layer > top_layer || entry_point.is_none() {
273                top_layer = max_layer;
274                entry_point = Some(id);
275            }
276            idx.nodes.insert(id, Node { layers });
277        }
278        idx.top_layer = top_layer;
279        idx.entry_point = entry_point;
280        idx
281    }
282
283    /// Inserts a node into the graph. The node id must be unique;
284    /// re-inserting an existing id is a no-op (returns without error).
285    /// `vec` is the new node's vector; `get_vec` looks up the vector
286    /// for any other node id the algorithm touches.
287    pub fn insert<F>(&mut self, node_id: i64, vec: &[f32], get_vec: F)
288    where
289        F: Fn(i64) -> Vec<f32>,
290    {
291        if self.nodes.contains_key(&node_id) {
292            return;
293        }
294
295        // First node: trivial case. Becomes entry point at layer 0.
296        if self.is_empty() {
297            self.nodes.insert(
298                node_id,
299                Node {
300                    layers: vec![Vec::new()],
301                },
302            );
303            self.entry_point = Some(node_id);
304            self.top_layer = 0;
305            return;
306        }
307
308        // Pick a layer for this new node.
309        let target_layer = self.pick_layer();
310
311        // Pre-allocate the new node's layer lists (empty for now;
312        // populated below).
313        let new_node = Node {
314            layers: vec![Vec::new(); target_layer + 1],
315        };
316        self.nodes.insert(node_id, new_node);
317
318        // Greedy descent from top down to (target_layer + 1) — at each
319        // layer above our target, advance the entry point to the
320        // single closest node. We don't add edges at these layers
321        // because the new node doesn't live there.
322        let mut entry = self.entry_point.expect("non-empty index has entry point");
323        for layer in (target_layer + 1..=self.top_layer).rev() {
324            let nearest = self.search_layer(vec, &[entry], 1, layer, &get_vec);
325            if let Some((_, id)) = nearest.into_iter().next() {
326                entry = id;
327            }
328        }
329
330        // Beam search + connect at each layer the new node lives in.
331        // We work top-down; the entry point for each layer is the best
332        // candidate found at the layer above.
333        let mut entries = vec![entry];
334        for layer in (0..=target_layer).rev() {
335            let candidates =
336                self.search_layer(vec, &entries, self.params.ef_construction, layer, &get_vec);
337
338            // Pick up to M neighbors from candidates (M_max0 at layer 0
339            // since we allow more connections at the dense layer).
340            let m_max = if layer == 0 {
341                self.params.m_max0
342            } else {
343                self.params.m
344            };
345            let neighbors: Vec<i64> = candidates
346                .iter()
347                .take(self.params.m)
348                .map(|(_, id)| *id)
349                .collect();
350
351            // Wire up the bidirectional edges.
352            self.nodes.get_mut(&node_id).expect("just inserted").layers[layer] = neighbors.clone();
353
354            for &nb in &neighbors {
355                let nb_layers = &mut self.nodes.get_mut(&nb).expect("neighbor must exist").layers;
356                if layer >= nb_layers.len() {
357                    // Neighbor doesn't actually live at this layer — shouldn't
358                    // happen because search_layer only returns nodes at this
359                    // layer, but defend against it.
360                    continue;
361                }
362                nb_layers[layer].push(node_id);
363
364                // Prune the neighbor's edge list if it's now over its M_max
365                // budget. Pruning policy: keep the closest M_max nodes
366                // by distance. (Distance recomputed; no precomputed values.)
367                if nb_layers[layer].len() > m_max {
368                    let nb_vec = get_vec(nb);
369                    let mut by_dist: Vec<(f32, i64)> = nb_layers[layer]
370                        .iter()
371                        .map(|id| (self.distance.compute(&nb_vec, &get_vec(*id)), *id))
372                        .collect();
373                    by_dist
374                        .sort_by(|(da, _), (db, _)| da.partial_cmp(db).unwrap_or(Ordering::Equal));
375                    by_dist.truncate(m_max);
376                    nb_layers[layer] = by_dist.into_iter().map(|(_, id)| id).collect();
377                }
378            }
379
380            // Carry the candidate set forward as entry points for the
381            // next (lower) layer.
382            entries = candidates.into_iter().map(|(_, id)| id).collect();
383        }
384
385        // If this new node lives higher than the current top, promote it.
386        if target_layer > self.top_layer {
387            self.top_layer = target_layer;
388            self.entry_point = Some(node_id);
389        }
390    }
391
392    /// Returns the k nearest node ids to `query`, in distance-ascending
393    /// order (closest first). Empty index returns an empty Vec.
394    pub fn search<F>(&self, query: &[f32], k: usize, get_vec: F) -> Vec<i64>
395    where
396        F: Fn(i64) -> Vec<f32>,
397    {
398        if self.is_empty() || k == 0 {
399            return Vec::new();
400        }
401
402        // Greedy descent from the top down to layer 1.
403        let mut entry = self.entry_point.expect("non-empty index has entry point");
404        for layer in (1..=self.top_layer).rev() {
405            let nearest = self.search_layer(query, &[entry], 1, layer, &get_vec);
406            if let Some((_, id)) = nearest.into_iter().next() {
407                entry = id;
408            }
409        }
410
411        // Beam search at layer 0 with width = max(ef_search, k).
412        let ef = self.params.ef_search.max(k);
413        let candidates = self.search_layer(query, &[entry], ef, 0, &get_vec);
414
415        candidates.into_iter().take(k).map(|(_, id)| id).collect()
416    }
417
418    /// Runs a beam search at one layer starting from `entries`, returning
419    /// the top-`ef` nearest nodes to `query` found, sorted by distance
420    /// ascending.
421    ///
422    /// This is the workhorse of both insert and search. The two priority
423    /// queues — "candidates" (nodes still to expand) and "results"
424    /// (current best ef found) — terminate when the closest unexpanded
425    /// candidate is farther than the worst kept result.
426    fn search_layer<F>(
427        &self,
428        query: &[f32],
429        entries: &[i64],
430        ef: usize,
431        layer: usize,
432        get_vec: &F,
433    ) -> Vec<(f32, i64)>
434    where
435        F: Fn(i64) -> Vec<f32>,
436    {
437        let mut visited: HashSet<i64> = HashSet::with_capacity(ef * 2);
438        // candidates: min-heap of (distance, id) — pop closest first.
439        let mut candidates: BinaryHeap<MinHeapItem> = BinaryHeap::with_capacity(ef * 2);
440        // results: max-heap of (distance, id) — top is the worst kept.
441        let mut results: BinaryHeap<MaxHeapItem> = BinaryHeap::with_capacity(ef);
442
443        for &id in entries {
444            if !visited.insert(id) {
445                continue;
446            }
447            let d = self.distance.compute(query, &get_vec(id));
448            candidates.push(MinHeapItem { dist: d, id });
449            results.push(MaxHeapItem { dist: d, id });
450        }
451
452        while let Some(MinHeapItem {
453            dist: c_dist,
454            id: c_id,
455        }) = candidates.pop()
456        {
457            // If the closest unexpanded candidate is worse than the
458            // worst kept result, no further expansion can improve the
459            // result set. Bail.
460            if let Some(worst) = results.peek() {
461                if results.len() >= ef && c_dist > worst.dist {
462                    break;
463                }
464            }
465
466            // Expand: visit each neighbor of c_id at this layer.
467            let neighbors = self
468                .nodes
469                .get(&c_id)
470                .and_then(|n| n.layers.get(layer))
471                .cloned()
472                .unwrap_or_default();
473            for nb in neighbors {
474                if !visited.insert(nb) {
475                    continue;
476                }
477                let d = self.distance.compute(query, &get_vec(nb));
478                let admit = if results.len() < ef {
479                    true
480                } else {
481                    d < results.peek().unwrap().dist
482                };
483                if admit {
484                    candidates.push(MinHeapItem { dist: d, id: nb });
485                    results.push(MaxHeapItem { dist: d, id: nb });
486                    if results.len() > ef {
487                        results.pop();
488                    }
489                }
490            }
491        }
492
493        // Drain results into a sorted vec. results is a max-heap, so
494        // popping gives descending order; reverse for ascending.
495        let mut out: Vec<(f32, i64)> = Vec::with_capacity(results.len());
496        while let Some(item) = results.pop() {
497            out.push((item.dist, item.id));
498        }
499        out.reverse();
500        out
501    }
502
503    /// Picks a layer for a new node using the standard HNSW geometric
504    /// distribution: `L = floor(-ln(uniform) * m_l)`. With M=16, mL ≈ 0.36,
505    /// so:
506    ///   - P(L=0) ≈ 1 - 1/M = 15/16
507    ///   - P(L=1) ≈ 1/16 - 1/256
508    ///   - P(L=2) ≈ 1/256 - …
509    /// i.e., most new nodes live only at layer 0; a few percolate up.
510    fn pick_layer(&mut self) -> usize {
511        let u = self.next_uniform().max(1e-6); // guard log(0)
512        let layer = (-u.ln() * self.params.m_l).floor() as usize;
513        // Cap at top_layer + 1 to keep the graph from sprouting empty
514        // layers above the current top — matches the original HNSW
515        // paper's recommendation.
516        layer.min(self.top_layer + 1)
517    }
518
519    /// Pulls a uniform-on-(0, 1] f32 from the internal xorshift state.
520    /// Top 24 bits of the next u64, divided by 2^24 — gives 24-bit
521    /// uniform precision, plenty for layer assignment.
522    fn next_uniform(&mut self) -> f32 {
523        let mut x = self.rng_state;
524        x ^= x << 13;
525        x ^= x >> 7;
526        x ^= x << 17;
527        self.rng_state = x;
528        ((x >> 40) as u32) as f32 / (1u32 << 24) as f32
529    }
530}
531
532// -----------------------------------------------------------------
533// Heap items
534//
535// Rust's BinaryHeap is a max-heap that uses Ord. f32 doesn't impl Ord
536// (NaN), so we wrap (distance, id) pairs and provide custom Ord that
537// uses partial_cmp with NaN treated as Greater (NaN sorts as worst).
538//
539// MinHeapItem inverts the comparison so BinaryHeap<MinHeapItem> behaves
540// as a min-heap — top is the smallest distance, popping gives ascending
541// order.
542//
543// MaxHeapItem uses the natural ordering — top is the largest distance.
544
545#[derive(Debug, Clone, Copy)]
546struct MinHeapItem {
547    dist: f32,
548    id: i64,
549}
550
551impl PartialEq for MinHeapItem {
552    fn eq(&self, other: &Self) -> bool {
553        self.dist == other.dist && self.id == other.id
554    }
555}
556impl Eq for MinHeapItem {}
557impl PartialOrd for MinHeapItem {
558    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
559        Some(self.cmp(other))
560    }
561}
562impl Ord for MinHeapItem {
563    fn cmp(&self, other: &Self) -> Ordering {
564        // Reverse so smallest distance bubbles to top.
565        other
566            .dist
567            .partial_cmp(&self.dist)
568            .unwrap_or(Ordering::Equal)
569            .then(other.id.cmp(&self.id))
570    }
571}
572
573#[derive(Debug, Clone, Copy)]
574struct MaxHeapItem {
575    dist: f32,
576    id: i64,
577}
578
579impl PartialEq for MaxHeapItem {
580    fn eq(&self, other: &Self) -> bool {
581        self.dist == other.dist && self.id == other.id
582    }
583}
584impl Eq for MaxHeapItem {}
585impl PartialOrd for MaxHeapItem {
586    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
587        Some(self.cmp(other))
588    }
589}
590impl Ord for MaxHeapItem {
591    fn cmp(&self, other: &Self) -> Ordering {
592        // Natural so largest distance bubbles to top.
593        self.dist
594            .partial_cmp(&other.dist)
595            .unwrap_or(Ordering::Equal)
596            .then(self.id.cmp(&other.id))
597    }
598}
599
600// -----------------------------------------------------------------
601// Tests
602// -----------------------------------------------------------------
603
604#[cfg(test)]
605mod tests {
606    use super::*;
607
608    /// Deterministic xorshift to generate test vectors.
609    fn random_vec(state: &mut u64, dim: usize) -> Vec<f32> {
610        (0..dim)
611            .map(|_| {
612                let mut x = *state;
613                x ^= x << 13;
614                x ^= x >> 7;
615                x ^= x << 17;
616                *state = x;
617                ((x >> 40) as u32) as f32 / (1u32 << 24) as f32
618            })
619            .collect()
620    }
621
622    /// Brute-force nearest-neighbors baseline for recall comparison.
623    fn brute_force_topk(
624        vectors: &[Vec<f32>],
625        query: &[f32],
626        k: usize,
627        metric: DistanceMetric,
628    ) -> Vec<i64> {
629        let mut by_dist: Vec<(f32, i64)> = vectors
630            .iter()
631            .enumerate()
632            .map(|(i, v)| (metric.compute(query, v), i as i64))
633            .collect();
634        by_dist.sort_by(|(a, _), (b, _)| a.partial_cmp(b).unwrap_or(Ordering::Equal));
635        by_dist.into_iter().take(k).map(|(_, id)| id).collect()
636    }
637
638    /// recall@k — fraction of the brute-force top-k that the HNSW
639    /// search also returned (in any order).
640    fn recall_at_k(hnsw_result: &[i64], baseline: &[i64]) -> f32 {
641        let baseline_set: HashSet<i64> = baseline.iter().copied().collect();
642        let hits = hnsw_result
643            .iter()
644            .filter(|id| baseline_set.contains(id))
645            .count();
646        hits as f32 / baseline.len() as f32
647    }
648
649    #[test]
650    fn empty_index_returns_empty_search() {
651        let idx = HnswIndex::new(DistanceMetric::L2, 42);
652        let vectors: Vec<Vec<f32>> = vec![];
653        let result = idx.search(&[0.0; 4], 5, |id| vectors[id as usize].clone());
654        assert!(result.is_empty());
655    }
656
657    #[test]
658    fn single_node_returns_only_itself() {
659        let mut idx = HnswIndex::new(DistanceMetric::L2, 42);
660        let v0 = vec![1.0, 2.0, 3.0];
661        let vectors = vec![v0.clone()];
662        idx.insert(0, &v0, |id| vectors[id as usize].clone());
663        let result = idx.search(&[0.0; 3], 5, |id| vectors[id as usize].clone());
664        assert_eq!(result, vec![0]);
665    }
666
667    #[test]
668    fn duplicate_insert_is_noop() {
669        let mut idx = HnswIndex::new(DistanceMetric::L2, 42);
670        let v0 = vec![1.0, 2.0];
671        let vectors = vec![v0.clone()];
672        idx.insert(0, &v0, |id| vectors[id as usize].clone());
673        idx.insert(0, &v0, |id| vectors[id as usize].clone());
674        assert_eq!(idx.len(), 1);
675    }
676
677    #[test]
678    fn k_zero_returns_empty() {
679        let mut idx = HnswIndex::new(DistanceMetric::L2, 42);
680        let vectors = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
681        for (i, v) in vectors.iter().enumerate() {
682            idx.insert(i as i64, v, |id| vectors[id as usize].clone());
683        }
684        let result = idx.search(&[0.5, 0.5], 0, |id| vectors[id as usize].clone());
685        assert!(result.is_empty());
686    }
687
688    #[test]
689    fn small_graph_finds_exact_nearest() {
690        // 5 well-separated points in 2D — HNSW should find the exact
691        // nearest with no recall loss for k=1 and k=3.
692        let vectors: Vec<Vec<f32>> = vec![
693            vec![0.0, 0.0],
694            vec![10.0, 0.0],
695            vec![0.0, 10.0],
696            vec![10.0, 10.0],
697            vec![5.0, 5.0],
698        ];
699        let mut idx = HnswIndex::new(DistanceMetric::L2, 42);
700        for (i, v) in vectors.iter().enumerate() {
701            idx.insert(i as i64, v, |id| vectors[id as usize].clone());
702        }
703
704        // Query at (1, 1): nearest is (0, 0).
705        let result = idx.search(&[1.0, 1.0], 1, |id| vectors[id as usize].clone());
706        assert_eq!(result, vec![0]);
707
708        // Query at (5.5, 5.5): top-3 should be id=4 (5,5), then any
709        // two of the corners at distance ~7.78.
710        let result = idx.search(&[5.5, 5.5], 3, |id| vectors[id as usize].clone());
711        assert_eq!(result.len(), 3);
712        assert_eq!(result[0], 4, "closest to (5.5,5.5) should be id=4");
713    }
714
715    #[test]
716    fn recall_at_10_is_high_on_random_vectors_l2() {
717        // Standard recall test: 1000 random vectors in 8D, query for
718        // top-10 with HNSW, compare to brute-force ground truth.
719        // Modern HNSW papers target recall@10 > 0.95; we should clear
720        // that comfortably on this small benchmark.
721        let mut state: u64 = 0xDEADBEEF;
722        let dim = 8;
723        let n = 1000;
724        let queries = 20;
725        let k = 10;
726
727        let vectors: Vec<Vec<f32>> = (0..n).map(|_| random_vec(&mut state, dim)).collect();
728
729        let mut idx = HnswIndex::new(DistanceMetric::L2, 42);
730        for (i, v) in vectors.iter().enumerate() {
731            idx.insert(i as i64, v, |id| vectors[id as usize].clone());
732        }
733
734        let mut total_recall = 0.0f32;
735        for _ in 0..queries {
736            let q = random_vec(&mut state, dim);
737            let hnsw_top = idx.search(&q, k, |id| vectors[id as usize].clone());
738            let baseline = brute_force_topk(&vectors, &q, k, DistanceMetric::L2);
739            total_recall += recall_at_k(&hnsw_top, &baseline);
740        }
741        let avg_recall = total_recall / queries as f32;
742        assert!(
743            avg_recall >= 0.95,
744            "recall@{k} dropped below 0.95: avg={avg_recall:.3}"
745        );
746    }
747
748    #[test]
749    fn recall_at_10_is_high_on_random_vectors_cosine() {
750        // Same shape as the L2 test but with cosine distance, to
751        // exercise the alternative metric through the same pipeline.
752        let mut state: u64 = 0xC0FFEE;
753        let dim = 16;
754        let n = 500;
755        let queries = 20;
756        let k = 10;
757
758        let vectors: Vec<Vec<f32>> = (0..n).map(|_| random_vec(&mut state, dim)).collect();
759
760        let mut idx = HnswIndex::new(DistanceMetric::Cosine, 42);
761        for (i, v) in vectors.iter().enumerate() {
762            idx.insert(i as i64, v, |id| vectors[id as usize].clone());
763        }
764
765        let mut total_recall = 0.0f32;
766        for _ in 0..queries {
767            let q = random_vec(&mut state, dim);
768            let hnsw_top = idx.search(&q, k, |id| vectors[id as usize].clone());
769            let baseline = brute_force_topk(&vectors, &q, k, DistanceMetric::Cosine);
770            total_recall += recall_at_k(&hnsw_top, &baseline);
771        }
772        let avg_recall = total_recall / queries as f32;
773        assert!(
774            avg_recall >= 0.95,
775            "cosine recall@{k} dropped below 0.95: avg={avg_recall:.3}"
776        );
777    }
778
779    #[test]
780    fn entry_point_promotes_when_higher_layer_node_inserted() {
781        // The graph's entry point should always be a node at the
782        // current top layer. Insert two nodes; if the second lands at
783        // a higher layer, it becomes the entry point.
784        // We can't easily force a particular layer (it's randomized),
785        // so check the invariant: after every insert, the entry node's
786        // max_layer == top_layer.
787        let mut state: u64 = 0xABCDEF;
788        let mut idx = HnswIndex::new(DistanceMetric::L2, 42);
789        let dim = 4;
790        let mut vectors: Vec<Vec<f32>> = Vec::new();
791        for i in 0..50 {
792            vectors.push(random_vec(&mut state, dim));
793            let v = vectors[i].clone();
794            idx.insert(i as i64, &v, |id| vectors[id as usize].clone());
795
796            // Check invariant.
797            let entry = idx.entry_point.expect("non-empty");
798            let entry_max = idx.nodes[&entry].max_layer();
799            assert_eq!(
800                entry_max, idx.top_layer,
801                "entry-point invariant broken at step {i}: entry {entry} has max_layer {entry_max}, top_layer is {}",
802                idx.top_layer
803            );
804        }
805    }
806
807    #[test]
808    fn neighbor_lists_respect_m_max() {
809        // After inserting 200 points with M=16 (so M_max0 = 32), no
810        // node should have more than 32 neighbors at layer 0 or more
811        // than 16 at any higher layer.
812        let mut state: u64 = 0x123456;
813        let mut idx = HnswIndex::new(DistanceMetric::L2, 42);
814        let dim = 4;
815        let mut vectors: Vec<Vec<f32>> = Vec::new();
816        for i in 0..200 {
817            vectors.push(random_vec(&mut state, dim));
818            let v = vectors[i].clone();
819            idx.insert(i as i64, &v, |id| vectors[id as usize].clone());
820        }
821
822        for (id, node) in &idx.nodes {
823            for (layer, neighbors) in node.layers.iter().enumerate() {
824                let cap = if layer == 0 {
825                    idx.params.m_max0
826                } else {
827                    idx.params.m
828                };
829                assert!(
830                    neighbors.len() <= cap,
831                    "node {id} layer {layer} has {} > cap {cap}",
832                    neighbors.len()
833                );
834            }
835        }
836    }
837
838    #[test]
839    fn deterministic_with_fixed_seed() {
840        // Same seed + same insert order → same graph topology.
841        // Catches accidental sources of nondeterminism (HashMap
842        // iteration order, etc.).
843        let mut state: u64 = 0x999;
844        let dim = 4;
845        let n = 50;
846        let vectors: Vec<Vec<f32>> = (0..n).map(|_| random_vec(&mut state, dim)).collect();
847
848        let mut idx_a = HnswIndex::new(DistanceMetric::L2, 42);
849        let mut idx_b = HnswIndex::new(DistanceMetric::L2, 42);
850        for (i, v) in vectors.iter().enumerate() {
851            idx_a.insert(i as i64, v, |id| vectors[id as usize].clone());
852            idx_b.insert(i as i64, v, |id| vectors[id as usize].clone());
853        }
854
855        // Same top layer.
856        assert_eq!(idx_a.top_layer, idx_b.top_layer);
857        // Same entry point.
858        assert_eq!(idx_a.entry_point, idx_b.entry_point);
859        // Same node count and same per-node max-layer for every id.
860        // (Neighbor list contents may differ trivially if HashMap
861        // iteration sneaked in; if this fails, fix the source first.)
862        assert_eq!(idx_a.nodes.len(), idx_b.nodes.len());
863        for (id, node_a) in &idx_a.nodes {
864            let node_b = idx_b.nodes.get(id).expect("missing id");
865            assert_eq!(node_a.max_layer(), node_b.max_layer(), "id={id}");
866        }
867    }
868}