1use crate::error::{IndexError, IndexResult};
27use crate::metric::Metric;
28use crate::PointId;
29use alloc::collections::BinaryHeap;
30use alloc::vec::Vec;
31use core::cmp::Ordering;
32use rand::Rng;
33use std::collections::{HashMap, HashSet};
34
/// Tuning parameters for an HNSW index.
#[derive(Debug, Clone, Copy)]
pub struct HnswConfig {
    /// Neighbor-list capacity used for every layer above layer 0.
    pub m: usize,
    /// Neighbor-list capacity used for layer 0.
    pub m_max0: usize,
    /// Candidate-list size (`ef`) used by the layer searches run during insertion.
    pub ef_construction: usize,
    /// Candidate-list size (`ef`) used by `search` when the caller does not pick one.
    pub ef_search: usize,
    /// Multiplier applied to `-ln(u)` when drawing the top layer of a new point.
    pub level_lambda: f32,
}

impl Default for HnswConfig {
    /// Builds the default configuration: `m = 16`, `m_max0 = 2 * m`,
    /// `ef_construction = 200`, `ef_search = 50`, `level_lambda = 1 / ln(m)`.
    fn default() -> Self {
        const DEFAULT_M: usize = 16;

        let level_lambda = 1.0 / (DEFAULT_M as f32).ln();
        HnswConfig {
            m: DEFAULT_M,
            m_max0: DEFAULT_M * 2,
            ef_construction: 200,
            ef_search: 50,
            level_lambda,
        }
    }
}
67
68impl HnswConfig {
69 fn validate(&self) -> IndexResult<()> {
70 if self.m == 0 {
71 return Err(IndexError::InvalidConfig("m must be > 0"));
72 }
73 if self.ef_construction < self.m {
74 return Err(IndexError::InvalidConfig("ef_construction must be >= m"));
75 }
76 if self.ef_search == 0 {
77 return Err(IndexError::InvalidConfig("ef_search must be > 0"));
78 }
79 if !self.level_lambda.is_finite() || self.level_lambda <= 0.0 {
80 return Err(IndexError::InvalidConfig(
81 "level_lambda must be finite and positive",
82 ));
83 }
84 Ok(())
85 }
86}
87
88#[derive(Debug, Clone, Copy, PartialEq)]
93pub struct Neighbor {
94 pub id: PointId,
96 pub distance: f32,
98}
99
100impl Eq for Neighbor {}
101
102impl Ord for Neighbor {
103 fn cmp(&self, other: &Self) -> Ordering {
104 self.distance
106 .partial_cmp(&other.distance)
107 .unwrap_or(Ordering::Equal)
108 .then_with(|| self.id.cmp(&other.id))
109 }
110}
111
112impl PartialOrd for Neighbor {
113 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
114 Some(self.cmp(other))
115 }
116}
117
118#[derive(Debug, Clone, Copy, PartialEq)]
121struct MaxHeapEntry(Neighbor);
122impl Eq for MaxHeapEntry {}
123impl Ord for MaxHeapEntry {
124 fn cmp(&self, other: &Self) -> Ordering {
125 self.0.cmp(&other.0)
126 }
127}
128impl PartialOrd for MaxHeapEntry {
129 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
130 Some(self.cmp(other))
131 }
132}
133
134#[derive(Debug, Clone, Copy, PartialEq)]
137struct MinHeapEntry(Neighbor);
138impl Eq for MinHeapEntry {}
139impl Ord for MinHeapEntry {
140 fn cmp(&self, other: &Self) -> Ordering {
141 other.0.cmp(&self.0)
143 }
144}
145impl PartialOrd for MinHeapEntry {
146 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
147 Some(self.cmp(other))
148 }
149}
150
151struct Node<P> {
156 point: P,
157 neighbors: Vec<Vec<PointId>>,
158}
159
160pub struct HnswIndex<P, M>
165where
166 M: Metric<Point = P>,
167{
168 config: HnswConfig,
169 metric: M,
170 nodes: HashMap<PointId, Node<P>>,
171 entry_point: Option<(PointId, usize)>, dim: Option<usize>,
173}
174
175impl<P, M> HnswIndex<P, M>
176where
177 M: Metric<Point = P>,
178{
179 pub fn new(config: HnswConfig, metric: M) -> IndexResult<Self> {
185 config.validate()?;
186 Ok(Self {
187 config,
188 metric,
189 nodes: HashMap::new(),
190 entry_point: None,
191 dim: None,
192 })
193 }
194
195 pub fn len(&self) -> usize {
197 self.nodes.len()
198 }
199
200 pub fn is_empty(&self) -> bool {
202 self.nodes.is_empty()
203 }
204
205 pub fn get(&self, id: PointId) -> Option<&P> {
207 self.nodes.get(&id).map(|n| &n.point)
208 }
209
210 pub fn insert(&mut self, id: PointId, point: P) -> IndexResult<()> {
215 if self.nodes.contains_key(&id) {
216 return Err(IndexError::DuplicateId(id));
217 }
218
219 let point_dim = self.metric.dim(&point);
220 match self.dim {
221 None => self.dim = Some(point_dim),
222 Some(d) if d != point_dim => {
223 return Err(IndexError::DimensionMismatch {
224 expected: d,
225 actual: point_dim,
226 });
227 }
228 _ => {}
229 }
230
231 let level = self.random_level(&mut rand::thread_rng());
232
233 let neighbors = (0..=level)
235 .map(|lvl| {
236 let cap = if lvl == 0 {
237 self.config.m_max0
238 } else {
239 self.config.m
240 };
241 Vec::with_capacity(cap)
242 })
243 .collect();
244
245 let node = Node { point, neighbors };
246 self.nodes.insert(id, node);
247
248 let Some((entry_id, entry_level)) = self.entry_point else {
250 self.entry_point = Some((id, level));
251 return Ok(());
252 };
253
254 let mut nearest = entry_id;
257 for lvl in ((level + 1)..=entry_level).rev() {
258 nearest = self.greedy_search_one_level(id, nearest, lvl);
259 }
260
261 for lvl in (0..=level.min(entry_level)).rev() {
264 let candidates = self.search_layer(id, &[nearest], lvl, self.config.ef_construction);
265 let m_at_level = if lvl == 0 {
266 self.config.m_max0
267 } else {
268 self.config.m
269 };
270 let selected = self.select_neighbors_heuristic(candidates, m_at_level, true);
271
272 let new_neighbors_at_level: Vec<PointId> = selected.iter().map(|n| n.id).collect();
275
276 self.nodes.get_mut(&id).unwrap().neighbors[lvl] = new_neighbors_at_level.clone();
279
280 for neighbor in &new_neighbors_at_level {
281 self.add_back_edge(*neighbor, id, lvl);
282 }
283
284 if let Some(closest) = selected.first() {
286 nearest = closest.id;
287 }
288 }
289
290 if level > entry_level {
292 self.entry_point = Some((id, level));
293 }
294
295 Ok(())
296 }
297
298 pub fn search(&self, query: &P, k: usize) -> Vec<Neighbor> {
303 self.search_with_ef(query, k, self.config.ef_search)
304 }
305
306 pub fn search_with_ef(&self, query: &P, k: usize, ef: usize) -> Vec<Neighbor> {
311 let Some((entry_id, entry_level)) = self.entry_point else {
312 return Vec::new();
313 };
314 let ef = ef.max(k);
315
316 let mut nearest_id = entry_id;
318 for lvl in (1..=entry_level).rev() {
319 nearest_id = self.greedy_search_one_level_query(query, nearest_id, lvl);
320 }
321
322 let mut found = self.search_layer_query(query, &[nearest_id], 0, ef);
324 found.sort();
325 found.truncate(k);
326 found
327 }
328
329 fn random_level<R: Rng>(&self, rng: &mut R) -> usize {
334 let r: f32 = rng.gen_range(f32::MIN_POSITIVE..1.0);
336 (-r.ln() * self.config.level_lambda).floor() as usize
337 }
338
339 fn greedy_search_one_level(&self, query_id: PointId, entry: PointId, level: usize) -> PointId {
342 let query = &self.nodes[&query_id].point;
343 self.greedy_search_one_level_query(query, entry, level)
344 }
345
346 fn greedy_search_one_level_query(&self, query: &P, entry: PointId, level: usize) -> PointId {
348 let mut current = entry;
349 let mut current_dist = self.metric.distance(query, &self.nodes[&entry].point);
350 loop {
351 let mut improved = false;
352 let neighbors_at_level = self.nodes[¤t]
355 .neighbors
356 .get(level)
357 .map(Vec::as_slice)
358 .unwrap_or(&[]);
359 for &nbr in neighbors_at_level {
360 let d = self.metric.distance(query, &self.nodes[&nbr].point);
361 if d < current_dist {
362 current_dist = d;
363 current = nbr;
364 improved = true;
365 }
366 }
367 if !improved {
368 return current;
369 }
370 }
371 }
372
373 fn search_layer(
376 &self,
377 query_id: PointId,
378 entry_points: &[PointId],
379 level: usize,
380 ef: usize,
381 ) -> Vec<Neighbor> {
382 let query = &self.nodes[&query_id].point;
383 self.search_layer_query_with_exclude(query, entry_points, level, ef, Some(query_id))
385 }
386
387 fn search_layer_query(
389 &self,
390 query: &P,
391 entry_points: &[PointId],
392 level: usize,
393 ef: usize,
394 ) -> Vec<Neighbor> {
395 self.search_layer_query_with_exclude(query, entry_points, level, ef, None)
396 }
397
398 fn search_layer_query_with_exclude(
399 &self,
400 query: &P,
401 entry_points: &[PointId],
402 level: usize,
403 ef: usize,
404 exclude: Option<PointId>,
405 ) -> Vec<Neighbor> {
406 let mut visited: HashSet<PointId> = HashSet::with_capacity(ef * 2);
407 let mut frontier: BinaryHeap<MinHeapEntry> = BinaryHeap::new(); let mut results: BinaryHeap<MaxHeapEntry> = BinaryHeap::new(); for &ep in entry_points {
411 if Some(ep) == exclude {
412 continue;
413 }
414 if !visited.insert(ep) {
415 continue;
416 }
417 let d = self.metric.distance(query, &self.nodes[&ep].point);
418 let n = Neighbor {
419 id: ep,
420 distance: d,
421 };
422 frontier.push(MinHeapEntry(n));
423 results.push(MaxHeapEntry(n));
424 }
425
426 while let Some(MinHeapEntry(closest)) = frontier.pop() {
427 if results.len() >= ef {
432 if let Some(MaxHeapEntry(worst)) = results.peek() {
433 if closest.distance > worst.distance {
434 break;
435 }
436 }
437 }
438
439 let neighbors_at_level = self.nodes[&closest.id]
440 .neighbors
441 .get(level)
442 .map(Vec::as_slice)
443 .unwrap_or(&[]);
444 for &nbr in neighbors_at_level {
445 if Some(nbr) == exclude {
446 continue;
447 }
448 if !visited.insert(nbr) {
449 continue;
450 }
451 let d = self.metric.distance(query, &self.nodes[&nbr].point);
452 let cand = Neighbor {
453 id: nbr,
454 distance: d,
455 };
456
457 let should_push = match results.peek() {
458 Some(MaxHeapEntry(worst)) => d < worst.distance || results.len() < ef,
459 None => true,
460 };
461 if should_push {
462 frontier.push(MinHeapEntry(cand));
463 results.push(MaxHeapEntry(cand));
464 if results.len() > ef {
465 results.pop();
466 }
467 }
468 }
469 }
470
471 results.into_iter().map(|MaxHeapEntry(n)| n).collect()
472 }
473
474 fn select_neighbors_heuristic(
481 &self,
482 mut candidates: Vec<Neighbor>,
483 m: usize,
484 keep_pruned: bool,
485 ) -> Vec<Neighbor> {
486 candidates.sort();
487
488 let mut selected: Vec<Neighbor> = Vec::with_capacity(m);
489 let mut discarded: Vec<Neighbor> = Vec::new();
490
491 for cand in candidates {
492 if selected.len() >= m {
493 break;
494 }
495 let dominated = selected.iter().any(|r| {
498 self.metric
499 .distance(&self.nodes[&cand.id].point, &self.nodes[&r.id].point)
500 <= cand.distance
501 });
502
503 if dominated {
504 discarded.push(cand);
505 } else {
506 selected.push(cand);
507 }
508 }
509
510 if keep_pruned {
511 for d in discarded {
512 if selected.len() >= m {
513 break;
514 }
515 selected.push(d);
516 }
517 }
518
519 selected
520 }
521
522 fn add_back_edge(&mut self, from: PointId, to: PointId, level: usize) {
523 let m_at_level = if level == 0 {
524 self.config.m_max0
525 } else {
526 self.config.m
527 };
528
529 let mut current_list: Vec<PointId> = {
534 let node = self
535 .nodes
536 .get_mut(&from)
537 .expect("from id exists in nodes map");
538 if node.neighbors.len() <= level {
539 node.neighbors.resize_with(level + 1, Vec::new);
540 }
541 if node.neighbors[level].contains(&to) {
542 return;
544 }
545 core::mem::take(&mut node.neighbors[level])
546 };
547
548 current_list.push(to);
549
550 if current_list.len() <= m_at_level {
552 self.nodes
553 .get_mut(&from)
554 .expect("from still present")
555 .neighbors[level] = current_list;
556 return;
557 }
558
559 let scored: Vec<Neighbor> = current_list
562 .iter()
563 .map(|&cid| {
564 let d = self
565 .metric
566 .distance(&self.nodes[&from].point, &self.nodes[&cid].point);
567 Neighbor {
568 id: cid,
569 distance: d,
570 }
571 })
572 .collect();
573
574 let kept_ids: Vec<PointId> = self
575 .select_neighbors_heuristic(scored, m_at_level, true)
576 .into_iter()
577 .map(|n| n.id)
578 .collect();
579
580 self.nodes
581 .get_mut(&from)
582 .expect("from still present")
583 .neighbors[level] = kept_ids;
584 }
585}
586
587#[cfg(test)]
588mod tests {
589 use super::*;
590 use crate::metric::L2;
591
592 fn make_index() -> HnswIndex<Vec<f32>, L2> {
593 HnswIndex::new(HnswConfig::default(), L2).expect("default config valid")
594 }
595
596 #[test]
597 fn empty_index_search_returns_empty() {
598 let idx = make_index();
599 assert!(idx.search(&vec![1.0, 2.0, 3.0], 5).is_empty());
600 }
601
602 #[test]
603 fn single_point_returns_itself() {
604 let mut idx = make_index();
605 idx.insert(42, vec![1.0, 2.0, 3.0]).unwrap();
606 let results = idx.search(&vec![1.0, 2.0, 3.0], 5);
607 assert_eq!(results.len(), 1);
608 assert_eq!(results[0].id, 42);
609 assert_eq!(results[0].distance, 0.0);
610 }
611
612 #[test]
613 fn duplicate_id_rejected() {
614 let mut idx = make_index();
615 idx.insert(7, vec![0.0, 0.0]).unwrap();
616 let err = idx.insert(7, vec![1.0, 1.0]).unwrap_err();
617 assert!(matches!(err, IndexError::DuplicateId(7)));
618 }
619
620 #[test]
621 fn dim_mismatch_rejected() {
622 let mut idx = make_index();
623 idx.insert(0, vec![0.0_f32; 64]).unwrap();
624 let err = idx.insert(1, vec![0.0_f32; 32]).unwrap_err();
625 assert!(
626 matches!(
627 err,
628 IndexError::DimensionMismatch {
629 expected: 64,
630 actual: 32
631 }
632 ),
633 "expected DimensionMismatch, got {err:?}"
634 );
635 }
636
637 #[test]
638 fn nearest_neighbor_is_correct_on_grid() {
639 let mut idx = make_index();
640 let mut id = 0;
642 for x in 0..5 {
643 for y in 0..5 {
644 idx.insert(id, vec![x as f32, y as f32]).unwrap();
645 id += 1;
646 }
647 }
648 let res = idx.search(&vec![2.1, 2.1], 1);
650 assert_eq!(res.len(), 1);
651 assert_eq!(res[0].id, 12, "nearest to (2.1, 2.1) should be (2,2)");
652 }
653
654 #[test]
655 fn k_nearest_neighbors_sorted_by_distance() {
656 let mut idx = make_index();
657 for i in 0..20 {
658 idx.insert(i, vec![i as f32, 0.0]).unwrap();
659 }
660 let res = idx.search(&vec![10.0, 0.0], 5);
661 assert_eq!(res.len(), 5);
662 for w in res.windows(2) {
664 assert!(w[0].distance <= w[1].distance);
665 }
666 assert_eq!(res[0].id, 10);
668 }
669
670 #[test]
671 fn recall_against_brute_force_random_data() {
672 use rand::{rngs::StdRng, SeedableRng};
673 use rand_distr::{Distribution, StandardNormal};
674
675 let mut rng = StdRng::seed_from_u64(42);
676 let n = 500;
677 let dim = 16;
678
679 let points: Vec<Vec<f32>> = (0..n)
681 .map(|_| (0..dim).map(|_| StandardNormal.sample(&mut rng)).collect())
682 .collect();
683
684 let mut idx = make_index();
685 for (i, p) in points.iter().enumerate() {
686 idx.insert(i as u64, p.clone()).unwrap();
687 }
688
689 let metric = L2;
691 let k = 10;
692 let n_queries = 10;
693 let mut total_recall = 0.0;
694
695 for _ in 0..n_queries {
696 let query: Vec<f32> = (0..dim).map(|_| StandardNormal.sample(&mut rng)).collect();
697
698 let hnsw_ids: HashSet<PointId> =
699 idx.search(&query, k).into_iter().map(|n| n.id).collect();
700
701 let mut bf: Vec<Neighbor> = points
702 .iter()
703 .enumerate()
704 .map(|(i, p)| Neighbor {
705 id: i as u64,
706 distance: metric.distance(&query, p),
707 })
708 .collect();
709 bf.sort();
710 let bf_ids: HashSet<PointId> = bf.into_iter().take(k).map(|n| n.id).collect();
711
712 let intersection = hnsw_ids.intersection(&bf_ids).count();
713 total_recall += intersection as f32 / k as f32;
714 }
715
716 let avg_recall = total_recall / n_queries as f32;
717 assert!(
718 avg_recall >= 0.95,
719 "recall {avg_recall:.3} below threshold; check HNSW correctness"
720 );
721 }
722
723 #[test]
724 #[ignore = "slow: ~400 s in debug; run with `cargo test -- --ignored`"]
725 fn recall_at_realistic_scale() {
726 use rand::{rngs::StdRng, SeedableRng};
727 use rand_distr::{Distribution, StandardNormal};
728
729 let n = 5000;
730 let dim = 64;
731 let k = 10;
732 let n_queries = 20;
733 let seeds: [u64; 5] = [1, 2, 3, 4, 5];
734 let metric = L2;
735
736 let mut total_recall = 0.0f32;
737
738 for seed in seeds {
739 let mut rng = StdRng::seed_from_u64(seed);
740
741 let points: Vec<Vec<f32>> = (0..n)
742 .map(|_| (0..dim).map(|_| StandardNormal.sample(&mut rng)).collect())
743 .collect();
744
745 let mut idx = make_index();
746 for (i, p) in points.iter().enumerate() {
747 idx.insert(i as u64, p.clone()).unwrap();
748 }
749
750 for _ in 0..n_queries {
751 let query: Vec<f32> = (0..dim).map(|_| StandardNormal.sample(&mut rng)).collect();
752
753 let hnsw_ids: HashSet<PointId> =
754 idx.search(&query, k).into_iter().map(|n| n.id).collect();
755
756 let mut bf: Vec<Neighbor> = points
757 .iter()
758 .enumerate()
759 .map(|(i, p)| Neighbor {
760 id: i as u64,
761 distance: metric.distance(&query, p),
762 })
763 .collect();
764 bf.sort();
765 let bf_ids: HashSet<PointId> = bf.into_iter().take(k).map(|n| n.id).collect();
766
767 total_recall += hnsw_ids.intersection(&bf_ids).count() as f32 / k as f32;
768 }
769 }
770
771 let mean_recall = total_recall / (seeds.len() * n_queries) as f32;
772 println!("recall_at_realistic_scale: mean_recall = {mean_recall:.4}");
773
774 assert!(
775 mean_recall >= 0.90,
776 "mean recall {mean_recall:.3} below 0.90; HNSW graph quality degraded"
777 );
778 assert!(
781 mean_recall < 0.999,
782 "mean recall {mean_recall:.3} implausibly perfect; test is no longer exercising ANN"
783 );
784 }
785}