1use anndists::prelude::Distance;
39use memmap2::Mmap;
40use rand::{prelude::*, thread_rng};
41use rayon::prelude::*;
42use serde::{Deserialize, Serialize};
43use std::cmp::{Ordering, Reverse};
44use std::collections::{BinaryHeap, HashSet};
45use std::fs::OpenOptions;
46use std::io::{Read, Seek, SeekFrom, Write};
47use std::marker::PhantomData;
48use thiserror::Error;
49
/// Sentinel id used to pad every on-disk adjacency row up to `max_degree`
/// entries; searches skip slots holding this value.
const PAD_U32: u32 = u32::MAX;

/// Default maximum number of neighbors kept per node.
pub const DISKANN_DEFAULT_MAX_DEGREE: usize = 64;
/// Default beam width of the greedy searches run during construction.
pub const DISKANN_DEFAULT_BUILD_BEAM: usize = 128;
/// Default pruning slack; a larger value makes robust pruning occlude fewer
/// candidates, so more edges survive.
pub const DISKANN_DEFAULT_ALPHA: f32 = 1.2;
/// Default number of refinement passes over the dataset.
pub const DISKANN_DEFAULT_PASSES: usize = 1;
/// Default number of random search seeds used in addition to the medoid.
pub const DISKANN_DEFAULT_EXTRA_SEEDS: usize = 1;

/// Multiplier on `max_degree` bounding how long a neighbor list may grow from
/// reverse-edge insertions before it is pruned back down.
const GRAPH_SLACK_FACTOR: f32 = 1.3;
66
67#[derive(Clone, Copy, Debug)]
69pub struct DiskAnnParams {
70 pub max_degree: usize,
71 pub build_beam_width: usize,
72 pub alpha: f32,
73 pub passes: usize,
75 pub extra_seeds: usize,
77}
78
79impl Default for DiskAnnParams {
80 fn default() -> Self {
81 Self {
82 max_degree: DISKANN_DEFAULT_MAX_DEGREE,
83 build_beam_width: DISKANN_DEFAULT_BUILD_BEAM,
84 alpha: DISKANN_DEFAULT_ALPHA,
85 passes: DISKANN_DEFAULT_PASSES,
86 extra_seeds: DISKANN_DEFAULT_EXTRA_SEEDS,
87 }
88 }
89}
90
91#[derive(Debug, Error)]
93pub enum DiskAnnError {
94 #[error("I/O error: {0}")]
96 Io(#[from] std::io::Error),
97
98 #[error("Serialization error: {0}")]
100 Bincode(#[from] bincode::Error),
101
102 #[error("Index error: {0}")]
104 IndexError(String),
105}
106
107#[derive(Serialize, Deserialize, Debug)]
109struct Metadata {
110 dim: usize,
111 num_vectors: usize,
112 max_degree: usize,
113 medoid_id: u32,
114 vectors_offset: u64,
115 adjacency_offset: u64,
116 elem_size: u8,
117 distance_name: String,
118}
119
/// A node id paired with its distance to the current query.
///
/// Candidates are totally ordered by distance (via [`f32::total_cmp`], so NaN
/// has a well-defined place) and then by id, which lets them live in heaps and
/// sorted vectors.
#[derive(Clone, Copy, Debug)]
struct Candidate {
    dist: f32,
    id: u32,
}

impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        // `total_cmp` reports `Equal` exactly when the bit patterns match, so
        // bitwise equality keeps `Eq` consistent with `Ord`.
        self.id == other.id && self.dist.to_bits() == other.dist.to_bits()
    }
}

impl Eq for Candidate {}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // Canonical form: the order is total, so defer to `Ord`.
        Some(self.cmp(other))
    }
}

impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> Ordering {
        self.dist
            .total_cmp(&other.dist)
            .then_with(|| self.id.cmp(&other.id))
    }
}
146
147#[derive(Clone, Debug)]
151struct FlatVectors<T> {
152 data: Vec<T>,
153 dim: usize,
154 n: usize,
155}
156
157impl<T: Copy> FlatVectors<T> {
158 fn from_vecs(vectors: &[Vec<T>]) -> Result<Self, DiskAnnError> {
159 if vectors.is_empty() {
160 return Err(DiskAnnError::IndexError("No vectors provided".to_string()));
161 }
162 let dim = vectors[0].len();
163 for (i, v) in vectors.iter().enumerate() {
164 if v.len() != dim {
165 return Err(DiskAnnError::IndexError(format!(
166 "Vector {} has dimension {} but expected {}",
167 i,
168 v.len(),
169 dim
170 )));
171 }
172 }
173
174 let n = vectors.len();
175 let mut data = Vec::with_capacity(n * dim);
176 for v in vectors {
177 data.extend_from_slice(v);
178 }
179
180 Ok(Self { data, dim, n })
181 }
182
183 #[inline]
184 fn row(&self, idx: usize) -> &[T] {
185 let start = idx * self.dim;
186 let end = start + self.dim;
187 &self.data[start..end]
188 }
189}
190
191#[derive(Default, Debug)]
202struct OrderedBeam {
203 items: Vec<Candidate>,
204}
205
206impl OrderedBeam {
207 #[inline]
208 fn clear(&mut self) {
209 self.items.clear();
210 }
211
212 #[inline]
213 fn len(&self) -> usize {
214 self.items.len()
215 }
216
217 #[inline]
218 fn is_empty(&self) -> bool {
219 self.items.is_empty()
220 }
221
222 #[inline]
223 fn best(&self) -> Option<Candidate> {
224 self.items.last().copied()
225 }
226
227 #[inline]
228 fn worst(&self) -> Option<Candidate> {
229 self.items.first().copied()
230 }
231
232 #[inline]
233 fn pop_best(&mut self) -> Option<Candidate> {
234 self.items.pop()
235 }
236
237 #[inline]
238 fn reserve(&mut self, cap: usize) {
239 if self.items.capacity() < cap {
240 self.items.reserve(cap - self.items.capacity());
241 }
242 }
243
244 #[inline]
245 fn insert_unbounded(&mut self, cand: Candidate) {
246 let pos = self.items.partition_point(|x| {
247 x.dist > cand.dist || (x.dist.to_bits() == cand.dist.to_bits() && x.id > cand.id)
248 });
249 self.items.insert(pos, cand);
250 }
251
252 #[inline]
253 fn insert_capped(&mut self, cand: Candidate, cap: usize) {
254 if cap == 0 {
255 return;
256 }
257
258 if self.items.len() < cap {
259 self.insert_unbounded(cand);
260 return;
261 }
262
263 let worst = self.items[0];
265 if cand.dist >= worst.dist {
266 return;
267 }
268
269 self.insert_unbounded(cand);
270
271 if self.items.len() > cap {
272 self.items.remove(0);
273 }
274 }
275}
276
277#[derive(Debug)]
281struct BuildScratch {
282 marks: Vec<u32>,
283 epoch: u32,
284
285 visited_ids: Vec<u32>,
286 visited_dists: Vec<f32>,
287
288 frontier: OrderedBeam,
289 work: OrderedBeam,
290
291 seeds: Vec<usize>,
292 candidates: Vec<(u32, f32)>,
293}
294
295impl BuildScratch {
296 fn new(n: usize, beam_width: usize, max_degree: usize, extra_seeds: usize) -> Self {
297 Self {
298 marks: vec![0u32; n],
299 epoch: 1,
300 visited_ids: Vec::with_capacity(beam_width * 4),
301 visited_dists: Vec::with_capacity(beam_width * 4),
302 frontier: {
303 let mut b = OrderedBeam::default();
304 b.reserve(beam_width * 2);
305 b
306 },
307 work: {
308 let mut b = OrderedBeam::default();
309 b.reserve(beam_width * 2);
310 b
311 },
312 seeds: Vec::with_capacity(1 + extra_seeds),
313 candidates: Vec::with_capacity(beam_width * (4 + extra_seeds) + max_degree * 2),
314 }
315 }
316
317 #[inline]
318 fn reset_search(&mut self) {
319 self.epoch = self.epoch.wrapping_add(1);
320 if self.epoch == 0 {
321 self.marks.fill(0);
322 self.epoch = 1;
323 }
324 self.visited_ids.clear();
325 self.visited_dists.clear();
326 self.frontier.clear();
327 self.work.clear();
328 }
329
330 #[inline]
331 fn is_marked(&self, idx: usize) -> bool {
332 self.marks[idx] == self.epoch
333 }
334
335 #[inline]
336 fn mark_with_dist(&mut self, idx: usize, dist: f32) {
337 self.marks[idx] = self.epoch;
338 self.visited_ids.push(idx as u32);
339 self.visited_dists.push(dist);
340 }
341}
342
343pub struct DiskANN<T, D>
345where
346 T: bytemuck::Pod + Copy + Send + Sync + 'static,
347 D: Distance<T> + Send + Sync + Copy + Clone + 'static,
348{
349 pub dim: usize,
351 pub num_vectors: usize,
353 pub max_degree: usize,
355 pub distance_name: String,
357
358 medoid_id: u32,
360 vectors_offset: u64,
362 adjacency_offset: u64,
363
364 mmap: Mmap,
366
367 dist: D,
369
370 _phantom: PhantomData<T>,
372}
373
374impl<T, D> DiskANN<T, D>
377where
378 T: bytemuck::Pod + Copy + Send + Sync + 'static,
379 D: Distance<T> + Send + Sync + Copy + Clone + 'static,
380{
381 pub fn build_index_default(
383 vectors: &[Vec<T>],
384 dist: D,
385 file_path: &str,
386 ) -> Result<Self, DiskAnnError> {
387 Self::build_index(
388 vectors,
389 DISKANN_DEFAULT_MAX_DEGREE,
390 DISKANN_DEFAULT_BUILD_BEAM,
391 DISKANN_DEFAULT_ALPHA,
392 DISKANN_DEFAULT_PASSES,
393 DISKANN_DEFAULT_EXTRA_SEEDS,
394 dist,
395 file_path,
396 )
397 }
398
399 pub fn build_index_with_params(
401 vectors: &[Vec<T>],
402 dist: D,
403 file_path: &str,
404 p: DiskAnnParams,
405 ) -> Result<Self, DiskAnnError> {
406 Self::build_index(
407 vectors,
408 p.max_degree,
409 p.build_beam_width,
410 p.alpha,
411 p.passes,
412 p.extra_seeds,
413 dist,
414 file_path,
415 )
416 }
417
418 pub fn open_index_with(path: &str, dist: D) -> Result<Self, DiskAnnError> {
420 let mut file = OpenOptions::new().read(true).write(false).open(path)?;
421
422 let mut buf8 = [0u8; 8];
424 file.seek(SeekFrom::Start(0))?;
425 file.read_exact(&mut buf8)?;
426 let md_len = u64::from_le_bytes(buf8);
427
428 let mut md_bytes = vec![0u8; md_len as usize];
430 file.read_exact(&mut md_bytes)?;
431 let metadata: Metadata = bincode::deserialize(&md_bytes)?;
432
433 let mmap = unsafe { memmap2::Mmap::map(&file)? };
434
435 let want = std::mem::size_of::<T>() as u8;
437 if metadata.elem_size != want {
438 return Err(DiskAnnError::IndexError(format!(
439 "element size mismatch: file has {}B, T is {}B",
440 metadata.elem_size, want
441 )));
442 }
443
444 let expected = std::any::type_name::<D>();
446 if metadata.distance_name != expected {
447 eprintln!(
448 "Warning: index recorded distance `{}` but you opened with `{}`",
449 metadata.distance_name, expected
450 );
451 }
452
453 Ok(Self {
454 dim: metadata.dim,
455 num_vectors: metadata.num_vectors,
456 max_degree: metadata.max_degree,
457 distance_name: metadata.distance_name,
458 medoid_id: metadata.medoid_id,
459 vectors_offset: metadata.vectors_offset,
460 adjacency_offset: metadata.adjacency_offset,
461 mmap,
462 dist,
463 _phantom: PhantomData,
464 })
465 }
466}
467
468impl<T, D> DiskANN<T, D>
470where
471 T: bytemuck::Pod + Copy + Send + Sync + 'static,
472 D: Distance<T> + Default + Send + Sync + Copy + Clone + 'static,
473{
474 pub fn build_index_default_metric(
476 vectors: &[Vec<T>],
477 file_path: &str,
478 ) -> Result<Self, DiskAnnError> {
479 Self::build_index_default(vectors, D::default(), file_path)
480 }
481
482 pub fn open_index_default_metric(path: &str) -> Result<Self, DiskAnnError> {
484 Self::open_index_with(path, D::default())
485 }
486}
487
488impl<T, D> DiskANN<T, D>
489where
490 T: bytemuck::Pod + Copy + Send + Sync + 'static,
491 D: Distance<T> + Send + Sync + Copy + Clone + 'static,
492{
493 pub fn build_index(
505 vectors: &[Vec<T>],
506 max_degree: usize,
507 build_beam_width: usize,
508 alpha: f32,
509 passes: usize,
510 extra_seeds: usize,
511 dist: D,
512 file_path: &str,
513 ) -> Result<Self, DiskAnnError> {
514 let flat = FlatVectors::from_vecs(vectors)?;
515
516 let num_vectors = flat.n;
517 let dim = flat.dim;
518
519 let mut file = OpenOptions::new()
520 .create(true)
521 .write(true)
522 .read(true)
523 .truncate(true)
524 .open(file_path)?;
525
526 let vectors_offset = 1024 * 1024;
528 assert_eq!(
529 (vectors_offset as usize) % std::mem::align_of::<T>(),
530 0,
531 "vectors_offset must be aligned for T"
532 );
533
534 let elem_sz = std::mem::size_of::<T>() as u64;
535 let total_vector_bytes = (num_vectors as u64) * (dim as u64) * elem_sz;
536
537 file.seek(SeekFrom::Start(vectors_offset as u64))?;
539 file.write_all(bytemuck::cast_slice::<T, u8>(&flat.data))?;
540
541 let medoid_id = calculate_medoid(&flat, dist);
543
544 let adjacency_offset = vectors_offset as u64 + total_vector_bytes;
546 let graph = build_vamana_graph(
547 &flat,
548 max_degree,
549 build_beam_width,
550 alpha,
551 passes,
552 extra_seeds,
553 dist,
554 medoid_id as u32,
555 );
556
557 file.seek(SeekFrom::Start(adjacency_offset))?;
559 for neighbors in &graph {
560 let mut padded = neighbors.clone();
561 padded.resize(max_degree, PAD_U32);
562 let bytes = bytemuck::cast_slice::<u32, u8>(&padded);
563 file.write_all(bytes)?;
564 }
565
566 let metadata = Metadata {
568 dim,
569 num_vectors,
570 max_degree,
571 medoid_id: medoid_id as u32,
572 vectors_offset: vectors_offset as u64,
573 adjacency_offset,
574 elem_size: std::mem::size_of::<T>() as u8,
575 distance_name: std::any::type_name::<D>().to_string(),
576 };
577
578 let md_bytes = bincode::serialize(&metadata)?;
579 file.seek(SeekFrom::Start(0))?;
580 let md_len = md_bytes.len() as u64;
581 file.write_all(&md_len.to_le_bytes())?;
582 file.write_all(&md_bytes)?;
583 file.sync_all()?;
584
585 let mmap = unsafe { memmap2::Mmap::map(&file)? };
587
588 Ok(Self {
589 dim,
590 num_vectors,
591 max_degree,
592 distance_name: metadata.distance_name,
593 medoid_id: metadata.medoid_id,
594 vectors_offset: metadata.vectors_offset,
595 adjacency_offset: metadata.adjacency_offset,
596 mmap,
597 dist,
598 _phantom: PhantomData,
599 })
600 }
601
602 pub fn search_with_dists(&self, query: &[T], k: usize, beam_width: usize) -> Vec<(u32, f32)> {
605 assert_eq!(
606 query.len(),
607 self.dim,
608 "Query dim {} != index dim {}",
609 query.len(),
610 self.dim
611 );
612
613 let mut visited = HashSet::new();
614 let mut frontier: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new();
615 let mut w: BinaryHeap<Candidate> = BinaryHeap::new();
616
617 let start_dist = self.distance_to(query, self.medoid_id as usize);
618 let start = Candidate {
619 dist: start_dist,
620 id: self.medoid_id,
621 };
622 frontier.push(Reverse(start));
623 w.push(start);
624 visited.insert(self.medoid_id);
625
626 while let Some(Reverse(best)) = frontier.peek().copied() {
627 if w.len() >= beam_width {
628 if let Some(worst) = w.peek() {
629 if best.dist >= worst.dist {
630 break;
631 }
632 }
633 }
634 let Reverse(current) = frontier.pop().unwrap();
635
636 for &nb in self.get_neighbors(current.id) {
637 if nb == PAD_U32 {
638 continue;
639 }
640 if !visited.insert(nb) {
641 continue;
642 }
643
644 let d = self.distance_to(query, nb as usize);
645 let cand = Candidate { dist: d, id: nb };
646
647 if w.len() < beam_width {
648 w.push(cand);
649 frontier.push(Reverse(cand));
650 } else if d < w.peek().unwrap().dist {
651 w.pop();
652 w.push(cand);
653 frontier.push(Reverse(cand));
654 }
655 }
656 }
657
658 let mut results: Vec<_> = w.into_vec();
659 results.sort_by(|a, b| a.dist.total_cmp(&b.dist));
660 results.truncate(k);
661 results.into_iter().map(|c| (c.id, c.dist)).collect()
662 }
663
664 pub fn search(&self, query: &[T], k: usize, beam_width: usize) -> Vec<u32> {
666 self.search_with_dists(query, k, beam_width)
667 .into_iter()
668 .map(|(id, _dist)| id)
669 .collect()
670 }
671
672 fn get_neighbors(&self, node_id: u32) -> &[u32] {
674 let offset = self.adjacency_offset + (node_id as u64 * self.max_degree as u64 * 4);
675 let start = offset as usize;
676 let end = start + (self.max_degree * 4);
677 let bytes = &self.mmap[start..end];
678 bytemuck::cast_slice(bytes)
679 }
680
681 fn distance_to(&self, query: &[T], idx: usize) -> f32 {
683 let elem_sz = std::mem::size_of::<T>();
684 let offset = self.vectors_offset + (idx as u64 * self.dim as u64 * elem_sz as u64);
685 let start = offset as usize;
686 let end = start + (self.dim * elem_sz);
687 let bytes = &self.mmap[start..end];
688 let vector: &[T] = bytemuck::cast_slice(bytes);
689 self.dist.eval(query, vector)
690 }
691
692 pub fn get_vector(&self, idx: usize) -> Vec<T> {
694 let elem_sz = std::mem::size_of::<T>();
695 let offset = self.vectors_offset + (idx as u64 * self.dim as u64 * elem_sz as u64);
696 let start = offset as usize;
697 let end = start + (self.dim * elem_sz);
698 let bytes = &self.mmap[start..end];
699 let vector: &[T] = bytemuck::cast_slice(bytes);
700 vector.to_vec()
701 }
702}
703
704fn calculate_medoid<T, D>(vectors: &FlatVectors<T>, dist: D) -> usize
706where
707 T: bytemuck::Pod + Copy + Send + Sync,
708 D: Distance<T> + Copy + Sync,
709{
710 let n = vectors.n;
711 let k = 8.min(n);
712 let mut rng = thread_rng();
713 let pivots: Vec<usize> = (0..k).map(|_| rng.gen_range(0..n)).collect();
714
715 let (best_idx, _best_score) = (0..n)
716 .into_par_iter()
717 .map(|i| {
718 let vi = vectors.row(i);
719 let score: f32 = pivots.iter().map(|&p| dist.eval(vi, vectors.row(p))).sum();
720 (i, score)
721 })
722 .reduce(|| (0usize, f32::MAX), |a, b| if a.1 <= b.1 { a } else { b });
723
724 best_idx
725}
726
727fn build_vamana_graph<T, D>(
729 vectors: &FlatVectors<T>,
730 max_degree: usize,
731 build_beam_width: usize,
732 alpha: f32,
733 passes: usize,
734 extra_seeds: usize,
735 dist: D,
736 medoid_id: u32,
737) -> Vec<Vec<u32>>
738where
739 T: bytemuck::Pod + Copy + Send + Sync,
740 D: Distance<T> + Copy + Sync,
741{
742 let n = vectors.n;
743 let mut graph = vec![Vec::<u32>::new(); n];
744
745 {
747 let mut rng = thread_rng();
748 let target = max_degree.min(n.saturating_sub(1));
749
750 for i in 0..n {
751 let mut s = HashSet::with_capacity(target);
752 while s.len() < target {
753 let nb = rng.gen_range(0..n);
754 if nb != i {
755 s.insert(nb as u32);
756 }
757 }
758 graph[i] = s.into_iter().collect();
759 }
760 }
761
762 let passes = passes.max(1);
763 let mut rng = thread_rng();
764
765 for pass_idx in 0..passes {
766 let pass_alpha = if passes == 1 {
767 alpha
768 } else if pass_idx == 0 {
769 1.0
770 } else {
771 alpha
772 };
773
774 let mut order: Vec<usize> = (0..n).collect();
775 order.shuffle(&mut rng);
776
777 for &u in &order {
779 let mut scratch = BuildScratch::new(n, build_beam_width, max_degree, extra_seeds);
780 scratch.candidates.clear();
781
782 for &nb in &graph[u] {
785 let d = dist.eval(vectors.row(u), vectors.row(nb as usize));
786 scratch.candidates.push((nb, d));
787 }
788
789 scratch.seeds.clear();
791 scratch.seeds.push(medoid_id as usize);
792 while scratch.seeds.len() < 1 + extra_seeds {
793 let s = rng.gen_range(0..n);
794 if !scratch.seeds.contains(&s) {
795 scratch.seeds.push(s);
796 }
797 }
798
799 let seeds = scratch.seeds.clone();
800 for start in seeds {
801 greedy_search_visited_collect(
802 vectors.row(u),
803 vectors,
804 &graph,
805 start,
806 build_beam_width,
807 dist,
808 &mut scratch,
809 );
810
811 for i in 0..scratch.visited_ids.len() {
812 scratch
813 .candidates
814 .push((scratch.visited_ids[i], scratch.visited_dists[i]));
815 }
816 }
817
818 scratch.candidates.sort_by(|a, b| a.0.cmp(&b.0));
820 scratch.candidates.dedup_by(|a, b| {
821 if a.0 == b.0 {
822 if a.1 < b.1 {
823 *b = *a;
824 }
825 true
826 } else {
827 false
828 }
829 });
830
831 let pruned = prune_neighbors(
832 u,
833 &scratch.candidates,
834 vectors,
835 max_degree,
836 pass_alpha,
837 dist,
838 );
839
840 graph[u] = pruned.clone();
842
843 inter_insert_with_slack(
845 &mut graph,
846 u as u32,
847 &pruned,
848 vectors,
849 max_degree,
850 pass_alpha,
851 dist,
852 );
853 }
854 }
855
856 graph
858 .into_iter()
859 .enumerate()
860 .map(|(u, neigh)| {
861 if neigh.len() <= max_degree {
862 return neigh;
863 }
864 let mut ids = neigh;
865 ids.sort_unstable();
866 ids.dedup();
867
868 let pool: Vec<(u32, f32)> = ids
869 .into_iter()
870 .filter(|&id| id as usize != u)
871 .map(|id| (id, dist.eval(vectors.row(u), vectors.row(id as usize))))
872 .collect();
873
874 prune_neighbors(u, &pool, vectors, max_degree, alpha, dist)
875 })
876 .collect()
877}
878
879fn inter_insert_with_slack<T, D>(
883 graph: &mut [Vec<u32>],
884 src: u32,
885 pruned_list: &[u32],
886 vectors: &FlatVectors<T>,
887 max_degree: usize,
888 alpha: f32,
889 dist: D,
890) where
891 T: bytemuck::Pod + Copy + Send + Sync,
892 D: Distance<T> + Copy,
893{
894 let slack_limit = ((GRAPH_SLACK_FACTOR * max_degree as f32).ceil() as usize).max(max_degree);
895
896 for &dst in pruned_list {
897 let dst_usize = dst as usize;
898 let src_usize = src as usize;
899
900 if dst_usize == src_usize {
901 continue;
902 }
903
904 if graph[dst_usize].contains(&src) {
906 continue;
907 }
908
909 if graph[dst_usize].len() < slack_limit {
910 graph[dst_usize].push(src);
911 continue;
912 }
913
914 let mut ids = graph[dst_usize].clone();
916 ids.push(src);
917 ids.sort_unstable();
918 ids.dedup();
919
920 let pool: Vec<(u32, f32)> = ids
921 .into_iter()
922 .filter(|&id| id as usize != dst_usize)
923 .map(|id| (id, dist.eval(vectors.row(dst_usize), vectors.row(id as usize))))
924 .collect();
925
926 graph[dst_usize] = prune_neighbors(dst_usize, &pool, vectors, max_degree, alpha, dist);
927 }
928}
929
930fn greedy_search_visited_collect<T, D>(
937 query: &[T],
938 vectors: &FlatVectors<T>,
939 graph: &[Vec<u32>],
940 start_id: usize,
941 beam_width: usize,
942 dist: D,
943 scratch: &mut BuildScratch,
944) where
945 T: bytemuck::Pod + Copy + Send + Sync,
946 D: Distance<T> + Copy,
947{
948 scratch.reset_search();
949
950 let start_dist = dist.eval(query, vectors.row(start_id));
951 let start = Candidate {
952 dist: start_dist,
953 id: start_id as u32,
954 };
955
956 scratch.frontier.insert_unbounded(start);
957 scratch.work.insert_capped(start, beam_width);
958 scratch.mark_with_dist(start_id, start_dist);
959
960 while !scratch.frontier.is_empty() {
961 let best = scratch.frontier.best().unwrap();
962 if scratch.work.len() >= beam_width {
963 if let Some(worst) = scratch.work.worst() {
964 if best.dist >= worst.dist {
965 break;
966 }
967 }
968 }
969
970 let cur = scratch.frontier.pop_best().unwrap();
971
972 for &nb in &graph[cur.id as usize] {
973 let nb_usize = nb as usize;
974 if scratch.is_marked(nb_usize) {
975 continue;
976 }
977
978 let d = dist.eval(query, vectors.row(nb_usize));
979 scratch.mark_with_dist(nb_usize, d);
980
981 let cand = Candidate { dist: d, id: nb };
982
983 if scratch.work.len() < beam_width {
984 scratch.work.insert_unbounded(cand);
985 scratch.frontier.insert_unbounded(cand);
986 } else if let Some(worst) = scratch.work.worst() {
987 if d < worst.dist {
988 scratch.work.insert_capped(cand, beam_width);
989 scratch.frontier.insert_unbounded(cand);
990 }
991 }
992 }
993 }
994}
995
996fn prune_neighbors<T, D>(
998 node_id: usize,
999 candidates: &[(u32, f32)],
1000 vectors: &FlatVectors<T>,
1001 max_degree: usize,
1002 alpha: f32,
1003 dist: D,
1004) -> Vec<u32>
1005where
1006 T: bytemuck::Pod + Copy + Send + Sync,
1007 D: Distance<T> + Copy,
1008{
1009 if candidates.is_empty() || max_degree == 0 {
1010 return Vec::new();
1011 }
1012
1013 let mut sorted = candidates.to_vec();
1015 sorted.sort_by(|a, b| a.1.total_cmp(&b.1));
1016
1017 let mut uniq = Vec::<(u32, f32)>::with_capacity(sorted.len());
1019 let mut last_id: Option<u32> = None;
1020 for &(cand_id, cand_dist) in &sorted {
1021 if cand_id as usize == node_id {
1022 continue;
1023 }
1024 if last_id == Some(cand_id) {
1025 continue;
1026 }
1027 uniq.push((cand_id, cand_dist));
1028 last_id = Some(cand_id);
1029 }
1030
1031 let mut pruned = Vec::<u32>::with_capacity(max_degree);
1032
1033 for &(cand_id, cand_dist_to_node) in &uniq {
1035 let mut occluded = false;
1036
1037 for &sel_id in &pruned {
1038 let d_cand_sel = dist.eval(
1039 vectors.row(cand_id as usize),
1040 vectors.row(sel_id as usize),
1041 );
1042
1043 if alpha * d_cand_sel <= cand_dist_to_node {
1045 occluded = true;
1046 break;
1047 }
1048 }
1049
1050 if !occluded {
1051 pruned.push(cand_id);
1052 if pruned.len() >= max_degree {
1053 break;
1054 }
1055 }
1056 }
1057
1058 pruned
1059}
1060
#[cfg(test)]
mod tests {
    use super::*;
    use anndists::dist::{DistCosine, DistL2};
    use rand::Rng;
    use std::fs;
    use std::path::PathBuf;

    /// Index file path in the system temp directory, deleted on drop.
    ///
    /// Panicking tests clean up too, and nothing is written to the working
    /// directory. The process id keeps concurrent test runs from colliding.
    struct TempIndex(PathBuf);

    impl TempIndex {
        fn new(name: &str) -> Self {
            let path = std::env::temp_dir()
                .join(format!("diskann_{}_{}", std::process::id(), name));
            let _ = fs::remove_file(&path);
            TempIndex(path)
        }

        fn path(&self) -> &str {
            self.0.to_str().expect("temp dir path must be valid UTF-8")
        }
    }

    impl Drop for TempIndex {
        fn drop(&mut self) {
            let _ = fs::remove_file(&self.0);
        }
    }

    fn euclid(a: &[f32], b: &[f32]) -> f32 {
        a.iter()
            .zip(b)
            .map(|(x, y)| (x - y) * (x - y))
            .sum::<f32>()
            .sqrt()
    }

    #[test]
    fn test_small_index_l2() {
        let tmp = TempIndex::new("small_l2.db");

        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
            vec![0.5, 0.5],
        ];

        let index =
            DiskANN::<f32, DistL2>::build_index_default(&vectors, DistL2, tmp.path()).unwrap();

        let q = vec![0.1, 0.1];
        let nns = index.search(&q, 3, 8);
        assert_eq!(nns.len(), 3);

        let v = index.get_vector(nns[0] as usize);
        assert!(euclid(&q, &v) < 1.0);
    }

    #[test]
    fn test_cosine() {
        let tmp = TempIndex::new("cosine.db");

        let vectors = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
            vec![1.0, 1.0, 0.0],
            vec![1.0, 0.0, 1.0],
        ];

        let index = DiskANN::<f32, DistCosine>::build_index_default(&vectors, DistCosine, tmp.path())
            .unwrap();

        let q = vec![2.0, 0.0, 0.0];
        let nns = index.search(&q, 2, 8);
        assert_eq!(nns.len(), 2);

        let v = index.get_vector(nns[0] as usize);
        let dot = v.iter().zip(&q).map(|(a, b)| a * b).sum::<f32>();
        let n1 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        let n2 = q.iter().map(|x| x * x).sum::<f32>().sqrt();
        let cos = dot / (n1 * n2);
        assert!(cos > 0.7);
    }

    #[test]
    fn test_persistence_and_open() {
        let tmp = TempIndex::new("persist.db");

        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
        ];

        {
            let _idx =
                DiskANN::<f32, DistL2>::build_index_default(&vectors, DistL2, tmp.path()).unwrap();
        }

        let idx2 = DiskANN::<f32, DistL2>::open_index_default_metric(tmp.path()).unwrap();
        assert_eq!(idx2.num_vectors, 4);
        assert_eq!(idx2.dim, 2);

        let q = vec![0.9, 0.9];
        let res = idx2.search(&q, 2, 8);
        assert_eq!(res[0], 3);
    }

    #[test]
    fn test_grid_connectivity() {
        let tmp = TempIndex::new("grid.db");

        let mut vectors = Vec::new();
        for i in 0..5 {
            for j in 0..5 {
                vectors.push(vec![i as f32, j as f32]);
            }
        }

        let index = DiskANN::<f32, DistL2>::build_index_with_params(
            &vectors,
            DistL2,
            tmp.path(),
            DiskAnnParams {
                max_degree: 4,
                build_beam_width: 64,
                alpha: 1.5,
                passes: DISKANN_DEFAULT_PASSES,
                extra_seeds: DISKANN_DEFAULT_EXTRA_SEEDS,
            },
        )
        .unwrap();

        for target in 0..vectors.len() {
            let q = &vectors[target];
            let nns = index.search(q, 10, 32);
            if !nns.contains(&(target as u32)) {
                let v = index.get_vector(nns[0] as usize);
                assert!(euclid(q, &v) < 2.0);
            }
            for &nb in nns.iter().take(5) {
                let v = index.get_vector(nb as usize);
                assert!(euclid(q, &v) < 5.0);
            }
        }
    }

    #[test]
    fn test_medium_random() {
        let tmp = TempIndex::new("medium.db");

        let n = 200usize;
        let d = 32usize;
        let mut rng = rand::thread_rng();
        let vectors: Vec<Vec<f32>> = (0..n)
            .map(|_| (0..d).map(|_| rng.r#gen::<f32>()).collect())
            .collect();

        let index = DiskANN::<f32, DistL2>::build_index_with_params(
            &vectors,
            DistL2,
            tmp.path(),
            DiskAnnParams {
                max_degree: 32,
                build_beam_width: 128,
                alpha: 1.2,
                passes: DISKANN_DEFAULT_PASSES,
                extra_seeds: DISKANN_DEFAULT_EXTRA_SEEDS,
            },
        )
        .unwrap();

        let q: Vec<f32> = (0..d).map(|_| rng.r#gen::<f32>()).collect();
        let res = index.search(&q, 10, 64);
        assert_eq!(res.len(), 10);

        let dists: Vec<f32> = res
            .iter()
            .map(|&id| {
                let v = index.get_vector(id as usize);
                euclid(&q, &v)
            })
            .collect();
        let mut sorted = dists.clone();
        sorted.sort_by(|a, b| a.total_cmp(b));
        assert_eq!(dists, sorted);
    }
}