1use anndists::prelude::Distance;
39use memmap2::Mmap;
40use rand::{prelude::*, thread_rng};
41use rayon::prelude::*;
42use serde::{Deserialize, Serialize};
43use std::cmp::{Ordering, Reverse};
44use std::collections::{BinaryHeap, HashSet};
45use std::fs::OpenOptions;
46use std::io::{Read, Seek, SeekFrom, Write};
47use std::marker::PhantomData;
48use thiserror::Error;
49
/// Sentinel stored in unused neighbor slots: adjacency rows are padded with it
/// when the index is written, and the search loop skips it.
const PAD_U32: u32 = u32::MAX;

/// Default `max_degree` (neighbors stored per node) used by `build_index_default`.
pub const DISKANN_DEFAULT_MAX_DEGREE: usize = 64;
/// Default `build_beam_width` used during graph construction.
pub const DISKANN_DEFAULT_BUILD_BEAM: usize = 128;
/// Default pruning `alpha` used during graph construction.
pub const DISKANN_DEFAULT_ALPHA: f32 = 1.2;
/// Default number of refinement `passes` over the graph.
pub const DISKANN_DEFAULT_PASSES: usize = 2;
/// Default number of random seeds added next to the medoid for each build search.
pub const DISKANN_DEFAULT_EXTRA_SEEDS: usize = 2;
61
/// Build-time parameters accepted by `DiskANN::build_index_with_params`.
///
/// Each field is forwarded unchanged to the argument of the same name in
/// `DiskANN::build_index`.
#[derive(Clone, Copy, Debug)]
pub struct DiskAnnParams {
    /// Number of neighbor slots stored per node.
    pub max_degree: usize,
    /// Beam width of the greedy searches run while building the graph.
    pub build_beam_width: usize,
    /// Pruning factor passed to neighbor selection.
    pub alpha: f32,
    /// Number of refinement passes (the builder runs at least one).
    pub passes: usize,
    /// Number of random start nodes used in addition to the medoid.
    pub extra_seeds: usize,
}
73
74impl Default for DiskAnnParams {
75 fn default() -> Self {
76 Self {
77 max_degree: DISKANN_DEFAULT_MAX_DEGREE,
78 build_beam_width: DISKANN_DEFAULT_BUILD_BEAM,
79 alpha: DISKANN_DEFAULT_ALPHA,
80 passes: DISKANN_DEFAULT_PASSES,
81 extra_seeds: DISKANN_DEFAULT_EXTRA_SEEDS,
82 }
83 }
84}
85
/// Errors returned while building or opening an index.
#[derive(Debug, Error)]
pub enum DiskAnnError {
    /// Failure reported by a file operation (open, seek, read, write, mmap).
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),

    /// Failure while encoding or decoding the metadata header with `bincode`.
    #[error("Serialization error: {0}")]
    Bincode(#[from] bincode::Error),

    /// Invalid input or an index file that does not match the requested types.
    #[error("Index error: {0}")]
    IndexError(String),
}
101
/// Header describing the index file layout.
///
/// It is serialized with `bincode` and stored at the start of the file, right
/// after an 8-byte little-endian length prefix.
#[derive(Serialize, Deserialize, Debug)]
struct Metadata {
    /// Number of elements per vector.
    dim: usize,
    /// Number of vectors stored in the file.
    num_vectors: usize,
    /// Number of `u32` neighbor slots stored per node.
    max_degree: usize,
    /// Id of the node every search starts from.
    medoid_id: u32,
    /// Byte offset of the first vector.
    vectors_offset: u64,
    /// Byte offset of the first adjacency row.
    adjacency_offset: u64,
    /// `size_of::<T>()` of the element type, narrowed to `u8` when written.
    elem_size: u8,
    /// `std::any::type_name` of the distance type the index was built with.
    distance_name: String,
}
114
/// A node id paired with its distance to the current query.
///
/// Candidates are totally ordered by distance (using `f32::total_cmp`, so NaN
/// and signed zeros have a defined position) and then by id.
#[derive(Clone, Copy, Debug)]
struct Candidate {
    /// Distance from the query to node `id`.
    dist: f32,
    /// Node id.
    id: u32,
}

impl PartialEq for Candidate {
    /// Two candidates are equal when both the id and the exact distance bits match,
    /// which agrees with `Ord::cmp` returning `Ordering::Equal`.
    fn eq(&self, other: &Self) -> bool {
        self.id == other.id && self.dist.to_bits() == other.dist.to_bits()
    }
}

impl Eq for Candidate {}

impl PartialOrd for Candidate {
    /// Always `Some`: delegates to the total order defined by `Ord::cmp`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Candidate {
    /// Orders by distance first and breaks ties by id.
    fn cmp(&self, other: &Self) -> Ordering {
        self.dist
            .total_cmp(&other.dist)
            .then_with(|| self.id.cmp(&other.id))
    }
}
141
142#[derive(Clone, Debug)]
146struct FlatVectors<T> {
147 data: Vec<T>,
148 dim: usize,
149 n: usize,
150}
151
152impl<T: Copy> FlatVectors<T> {
153 fn from_vecs(vectors: &[Vec<T>]) -> Result<Self, DiskAnnError> {
154 if vectors.is_empty() {
155 return Err(DiskAnnError::IndexError("No vectors provided".to_string()));
156 }
157 let dim = vectors[0].len();
158 for (i, v) in vectors.iter().enumerate() {
159 if v.len() != dim {
160 return Err(DiskAnnError::IndexError(format!(
161 "Vector {} has dimension {} but expected {}",
162 i,
163 v.len(),
164 dim
165 )));
166 }
167 }
168
169 let n = vectors.len();
170 let mut data = Vec::with_capacity(n * dim);
171 for v in vectors {
172 data.extend_from_slice(v);
173 }
174
175 Ok(Self { data, dim, n })
176 }
177
178 #[inline]
179 fn row(&self, idx: usize) -> &[T] {
180 let start = idx * self.dim;
181 let end = start + self.dim;
182 &self.data[start..end]
183 }
184}
185
186#[derive(Default, Debug)]
197struct OrderedBeam {
198 items: Vec<Candidate>,
199}
200
201impl OrderedBeam {
202 #[inline]
203 fn clear(&mut self) {
204 self.items.clear();
205 }
206
207 #[inline]
208 fn len(&self) -> usize {
209 self.items.len()
210 }
211
212 #[inline]
213 fn is_empty(&self) -> bool {
214 self.items.is_empty()
215 }
216
217 #[inline]
218 fn best(&self) -> Option<Candidate> {
219 self.items.last().copied()
220 }
221
222 #[inline]
223 fn worst(&self) -> Option<Candidate> {
224 self.items.first().copied()
225 }
226
227 #[inline]
228 fn pop_best(&mut self) -> Option<Candidate> {
229 self.items.pop()
230 }
231
232 #[inline]
233 fn reserve(&mut self, cap: usize) {
234 if self.items.capacity() < cap {
235 self.items.reserve(cap - self.items.capacity());
236 }
237 }
238
239 #[inline]
240 fn insert_unbounded(&mut self, cand: Candidate) {
241 let pos = self.items.partition_point(|x| {
242 x.dist > cand.dist || (x.dist.to_bits() == cand.dist.to_bits() && x.id > cand.id)
243 });
244 self.items.insert(pos, cand);
245 }
246
247 #[inline]
248 fn insert_capped(&mut self, cand: Candidate, cap: usize) {
249 if cap == 0 {
250 return;
251 }
252
253 if self.items.len() < cap {
254 self.insert_unbounded(cand);
255 return;
256 }
257
258 let worst = self.items[0];
260 if cand.dist >= worst.dist {
261 return;
262 }
263
264 self.insert_unbounded(cand);
265
266 if self.items.len() > cap {
267 self.items.remove(0);
268 }
269 }
270}
271
272#[derive(Debug)]
276struct BuildScratch {
277 marks: Vec<u32>,
278 epoch: u32,
279
280 visited_ids: Vec<u32>,
281 visited_dists: Vec<f32>,
282
283 frontier: OrderedBeam,
284 work: OrderedBeam,
285
286 seeds: Vec<usize>,
287 candidates: Vec<(u32, f32)>,
288}
289
290impl BuildScratch {
291 fn new(n: usize, beam_width: usize, max_degree: usize, extra_seeds: usize) -> Self {
292 Self {
293 marks: vec![0u32; n],
294 epoch: 1,
295 visited_ids: Vec::with_capacity(beam_width * 4),
296 visited_dists: Vec::with_capacity(beam_width * 4),
297 frontier: {
298 let mut b = OrderedBeam::default();
299 b.reserve(beam_width * 2);
300 b
301 },
302 work: {
303 let mut b = OrderedBeam::default();
304 b.reserve(beam_width * 2);
305 b
306 },
307 seeds: Vec::with_capacity(1 + extra_seeds),
308 candidates: Vec::with_capacity(beam_width * (4 + extra_seeds) + max_degree * 2),
309 }
310 }
311
312 #[inline]
313 fn reset_search(&mut self) {
314 self.epoch = self.epoch.wrapping_add(1);
315 if self.epoch == 0 {
316 self.marks.fill(0);
317 self.epoch = 1;
318 }
319 self.visited_ids.clear();
320 self.visited_dists.clear();
321 self.frontier.clear();
322 self.work.clear();
323 }
324
325 #[inline]
326 fn is_marked(&self, idx: usize) -> bool {
327 self.marks[idx] == self.epoch
328 }
329
330 #[inline]
331 fn mark_with_dist(&mut self, idx: usize, dist: f32) {
332 self.marks[idx] = self.epoch;
333 self.visited_ids.push(idx as u32);
334 self.visited_dists.push(dist);
335 }
336}
337
/// Handle to an index file: vectors and adjacency rows are read through a
/// memory map, and distances are evaluated with `dist`.
pub struct DiskANN<T, D>
where
    T: bytemuck::Pod + Copy + Send + Sync + 'static,
    D: Distance<T> + Send + Sync + Copy + Clone + 'static,
{
    /// Number of elements per vector.
    pub dim: usize,
    /// Number of vectors in the index.
    pub num_vectors: usize,
    /// Number of neighbor slots stored per node.
    pub max_degree: usize,
    /// Type name of the distance recorded in the index metadata.
    pub distance_name: String,

    /// Node every search starts from.
    medoid_id: u32,
    /// Byte offset of the first vector inside the file.
    vectors_offset: u64,
    /// Byte offset of the first adjacency row inside the file.
    adjacency_offset: u64,

    /// Memory map of the index file.
    mmap: Mmap,

    /// Distance implementation used by searches.
    dist: D,

    /// Ties the handle to the element type `T` without storing a value of it.
    _phantom: PhantomData<T>,
}
368
369impl<T, D> DiskANN<T, D>
372where
373 T: bytemuck::Pod + Copy + Send + Sync + 'static,
374 D: Distance<T> + Send + Sync + Copy + Clone + 'static,
375{
376 pub fn build_index_default(
378 vectors: &[Vec<T>],
379 dist: D,
380 file_path: &str,
381 ) -> Result<Self, DiskAnnError> {
382 Self::build_index(
383 vectors,
384 DISKANN_DEFAULT_MAX_DEGREE,
385 DISKANN_DEFAULT_BUILD_BEAM,
386 DISKANN_DEFAULT_ALPHA,
387 DISKANN_DEFAULT_PASSES,
388 DISKANN_DEFAULT_EXTRA_SEEDS,
389 dist,
390 file_path,
391 )
392 }
393
394 pub fn build_index_with_params(
396 vectors: &[Vec<T>],
397 dist: D,
398 file_path: &str,
399 p: DiskAnnParams,
400 ) -> Result<Self, DiskAnnError> {
401 Self::build_index(
402 vectors,
403 p.max_degree,
404 p.build_beam_width,
405 p.alpha,
406 p.passes,
407 p.extra_seeds,
408 dist,
409 file_path,
410 )
411 }
412
413 pub fn open_index_with(path: &str, dist: D) -> Result<Self, DiskAnnError> {
415 let mut file = OpenOptions::new().read(true).write(false).open(path)?;
416
417 let mut buf8 = [0u8; 8];
419 file.seek(SeekFrom::Start(0))?;
420 file.read_exact(&mut buf8)?;
421 let md_len = u64::from_le_bytes(buf8);
422
423 let mut md_bytes = vec![0u8; md_len as usize];
425 file.read_exact(&mut md_bytes)?;
426 let metadata: Metadata = bincode::deserialize(&md_bytes)?;
427
428 let mmap = unsafe { memmap2::Mmap::map(&file)? };
429
430 let want = std::mem::size_of::<T>() as u8;
432 if metadata.elem_size != want {
433 return Err(DiskAnnError::IndexError(format!(
434 "element size mismatch: file has {}B, T is {}B",
435 metadata.elem_size, want
436 )));
437 }
438
439 let expected = std::any::type_name::<D>();
441 if metadata.distance_name != expected {
442 eprintln!(
443 "Warning: index recorded distance `{}` but you opened with `{}`",
444 metadata.distance_name, expected
445 );
446 }
447
448 Ok(Self {
449 dim: metadata.dim,
450 num_vectors: metadata.num_vectors,
451 max_degree: metadata.max_degree,
452 distance_name: metadata.distance_name,
453 medoid_id: metadata.medoid_id,
454 vectors_offset: metadata.vectors_offset,
455 adjacency_offset: metadata.adjacency_offset,
456 mmap,
457 dist,
458 _phantom: PhantomData,
459 })
460 }
461}
462
463impl<T, D> DiskANN<T, D>
465where
466 T: bytemuck::Pod + Copy + Send + Sync + 'static,
467 D: Distance<T> + Default + Send + Sync + Copy + Clone + 'static,
468{
469 pub fn build_index_default_metric(
471 vectors: &[Vec<T>],
472 file_path: &str,
473 ) -> Result<Self, DiskAnnError> {
474 Self::build_index_default(vectors, D::default(), file_path)
475 }
476
477 pub fn open_index_default_metric(path: &str) -> Result<Self, DiskAnnError> {
479 Self::open_index_with(path, D::default())
480 }
481}
482
483impl<T, D> DiskANN<T, D>
484where
485 T: bytemuck::Pod + Copy + Send + Sync + 'static,
486 D: Distance<T> + Send + Sync + Copy + Clone + 'static,
487{
488 pub fn build_index(
500 vectors: &[Vec<T>],
501 max_degree: usize,
502 build_beam_width: usize,
503 alpha: f32,
504 passes: usize,
505 extra_seeds: usize,
506 dist: D,
507 file_path: &str,
508 ) -> Result<Self, DiskAnnError> {
509 let flat = FlatVectors::from_vecs(vectors)?;
510
511 let num_vectors = flat.n;
512 let dim = flat.dim;
513
514 let mut file = OpenOptions::new()
515 .create(true)
516 .write(true)
517 .read(true)
518 .truncate(true)
519 .open(file_path)?;
520
521 let vectors_offset = 1024 * 1024;
523 assert_eq!(
524 (vectors_offset as usize) % std::mem::align_of::<T>(),
525 0,
526 "vectors_offset must be aligned for T"
527 );
528
529 let elem_sz = std::mem::size_of::<T>() as u64;
530 let total_vector_bytes = (num_vectors as u64) * (dim as u64) * elem_sz;
531
532 file.seek(SeekFrom::Start(vectors_offset as u64))?;
534 file.write_all(bytemuck::cast_slice::<T, u8>(&flat.data))?;
535
536 let medoid_id = calculate_medoid(&flat, dist);
538
539 let adjacency_offset = vectors_offset as u64 + total_vector_bytes;
541 let graph = build_vamana_graph(
542 &flat,
543 max_degree,
544 build_beam_width,
545 alpha,
546 passes,
547 extra_seeds,
548 dist,
549 medoid_id as u32,
550 );
551
552 file.seek(SeekFrom::Start(adjacency_offset))?;
554 for neighbors in &graph {
555 let mut padded = neighbors.clone();
556 padded.resize(max_degree, PAD_U32);
557 let bytes = bytemuck::cast_slice::<u32, u8>(&padded);
558 file.write_all(bytes)?;
559 }
560
561 let metadata = Metadata {
563 dim,
564 num_vectors,
565 max_degree,
566 medoid_id: medoid_id as u32,
567 vectors_offset: vectors_offset as u64,
568 adjacency_offset,
569 elem_size: std::mem::size_of::<T>() as u8,
570 distance_name: std::any::type_name::<D>().to_string(),
571 };
572
573 let md_bytes = bincode::serialize(&metadata)?;
574 file.seek(SeekFrom::Start(0))?;
575 let md_len = md_bytes.len() as u64;
576 file.write_all(&md_len.to_le_bytes())?;
577 file.write_all(&md_bytes)?;
578 file.sync_all()?;
579
580 let mmap = unsafe { memmap2::Mmap::map(&file)? };
582
583 Ok(Self {
584 dim,
585 num_vectors,
586 max_degree,
587 distance_name: metadata.distance_name,
588 medoid_id: metadata.medoid_id,
589 vectors_offset: metadata.vectors_offset,
590 adjacency_offset: metadata.adjacency_offset,
591 mmap,
592 dist,
593 _phantom: PhantomData,
594 })
595 }
596
597 pub fn search_with_dists(&self, query: &[T], k: usize, beam_width: usize) -> Vec<(u32, f32)> {
600 assert_eq!(
601 query.len(),
602 self.dim,
603 "Query dim {} != index dim {}",
604 query.len(),
605 self.dim
606 );
607
608 let mut visited = HashSet::new();
609 let mut frontier: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new();
610 let mut w: BinaryHeap<Candidate> = BinaryHeap::new();
611
612 let start_dist = self.distance_to(query, self.medoid_id as usize);
613 let start = Candidate {
614 dist: start_dist,
615 id: self.medoid_id,
616 };
617 frontier.push(Reverse(start));
618 w.push(start);
619 visited.insert(self.medoid_id);
620
621 while let Some(Reverse(best)) = frontier.peek().copied() {
622 if w.len() >= beam_width {
623 if let Some(worst) = w.peek() {
624 if best.dist >= worst.dist {
625 break;
626 }
627 }
628 }
629 let Reverse(current) = frontier.pop().unwrap();
630
631 for &nb in self.get_neighbors(current.id) {
632 if nb == PAD_U32 {
633 continue;
634 }
635 if !visited.insert(nb) {
636 continue;
637 }
638
639 let d = self.distance_to(query, nb as usize);
640 let cand = Candidate { dist: d, id: nb };
641
642 if w.len() < beam_width {
643 w.push(cand);
644 frontier.push(Reverse(cand));
645 } else if d < w.peek().unwrap().dist {
646 w.pop();
647 w.push(cand);
648 frontier.push(Reverse(cand));
649 }
650 }
651 }
652
653 let mut results: Vec<_> = w.into_vec();
654 results.sort_by(|a, b| a.dist.total_cmp(&b.dist));
655 results.truncate(k);
656 results.into_iter().map(|c| (c.id, c.dist)).collect()
657 }
658
659 pub fn search(&self, query: &[T], k: usize, beam_width: usize) -> Vec<u32> {
661 self.search_with_dists(query, k, beam_width)
662 .into_iter()
663 .map(|(id, _dist)| id)
664 .collect()
665 }
666
667 fn get_neighbors(&self, node_id: u32) -> &[u32] {
669 let offset = self.adjacency_offset + (node_id as u64 * self.max_degree as u64 * 4);
670 let start = offset as usize;
671 let end = start + (self.max_degree * 4);
672 let bytes = &self.mmap[start..end];
673 bytemuck::cast_slice(bytes)
674 }
675
676 fn distance_to(&self, query: &[T], idx: usize) -> f32 {
678 let elem_sz = std::mem::size_of::<T>();
679 let offset = self.vectors_offset + (idx as u64 * self.dim as u64 * elem_sz as u64);
680 let start = offset as usize;
681 let end = start + (self.dim * elem_sz);
682 let bytes = &self.mmap[start..end];
683 let vector: &[T] = bytemuck::cast_slice(bytes);
684 self.dist.eval(query, vector)
685 }
686
687 pub fn get_vector(&self, idx: usize) -> Vec<T> {
689 let elem_sz = std::mem::size_of::<T>();
690 let offset = self.vectors_offset + (idx as u64 * self.dim as u64 * elem_sz as u64);
691 let start = offset as usize;
692 let end = start + (self.dim * elem_sz);
693 let bytes = &self.mmap[start..end];
694 let vector: &[T] = bytemuck::cast_slice(bytes);
695 vector.to_vec()
696 }
697}
698
699fn calculate_medoid<T, D>(vectors: &FlatVectors<T>, dist: D) -> usize
701where
702 T: bytemuck::Pod + Copy + Send + Sync,
703 D: Distance<T> + Copy + Sync,
704{
705 let n = vectors.n;
706 let k = 8.min(n);
707 let mut rng = thread_rng();
708 let pivots: Vec<usize> = (0..k).map(|_| rng.gen_range(0..n)).collect();
709
710 let (best_idx, _best_score) = (0..n)
711 .into_par_iter()
712 .map(|i| {
713 let vi = vectors.row(i);
714 let score: f32 = pivots.iter().map(|&p| dist.eval(vi, vectors.row(p))).sum();
715 (i, score)
716 })
717 .reduce(|| (0usize, f32::MAX), |a, b| if a.1 <= b.1 { a } else { b });
718
719 best_idx
720}
721
722fn build_vamana_graph<T, D>(
729 vectors: &FlatVectors<T>,
730 max_degree: usize,
731 build_beam_width: usize,
732 alpha: f32,
733 passes: usize,
734 extra_seeds: usize,
735 dist: D,
736 medoid_id: u32,
737) -> Vec<Vec<u32>>
738where
739 T: bytemuck::Pod + Copy + Send + Sync,
740 D: Distance<T> + Copy + Sync,
741{
742 let n = vectors.n;
743 let mut graph = vec![Vec::<u32>::new(); n];
744
745 {
747 let mut rng = thread_rng();
748 let target = max_degree.min(n.saturating_sub(1));
749
750 for i in 0..n {
751 let mut s = HashSet::with_capacity(target);
752 while s.len() < target {
753 let nb = rng.gen_range(0..n);
754 if nb != i {
755 s.insert(nb as u32);
756 }
757 }
758 graph[i] = s.into_iter().collect();
759 }
760 }
761
762 let passes = passes.max(1);
763 let mut rng = thread_rng();
764
765 for pass_idx in 0..passes {
766 let pass_alpha = if passes == 1 {
767 alpha
768 } else if pass_idx == 0 {
769 1.0
770 } else {
771 alpha
772 };
773
774 let mut order: Vec<usize> = (0..n).collect();
775 order.shuffle(&mut rng);
776
777 let snapshot = &graph;
778
779 let new_graph: Vec<Vec<u32>> = order
780 .par_iter()
781 .map_init(
782 || BuildScratch::new(n, build_beam_width, max_degree, extra_seeds),
783 |scratch, &u| {
784 scratch.candidates.clear();
785
786 for &nb in &snapshot[u] {
788 let d = dist.eval(vectors.row(u), vectors.row(nb as usize));
789 scratch.candidates.push((nb, d));
790 }
791
792 scratch.seeds.clear();
794 scratch.seeds.push(medoid_id as usize);
795 let mut trng = thread_rng();
796 while scratch.seeds.len() < 1 + extra_seeds {
797 let s = trng.gen_range(0..n);
798 if !scratch.seeds.contains(&s) {
799 scratch.seeds.push(s);
800 }
801 }
802
803 let seeds = scratch.seeds.clone();
805 for start in seeds {
806 greedy_search_visited_collect(
807 vectors.row(u),
808 vectors,
809 snapshot,
810 start,
811 build_beam_width,
812 dist,
813 scratch,
814 );
815
816 for i in 0..scratch.visited_ids.len() {
817 scratch
818 .candidates
819 .push((scratch.visited_ids[i], scratch.visited_dists[i]));
820 }
821 }
822
823 scratch.candidates.sort_by(|a, b| a.0.cmp(&b.0));
825 scratch.candidates.dedup_by(|a, b| {
826 if a.0 == b.0 {
827 if a.1 < b.1 {
828 *b = *a;
829 }
830 true
831 } else {
832 false
833 }
834 });
835
836 prune_neighbors(
837 u,
838 &scratch.candidates,
839 vectors,
840 max_degree,
841 pass_alpha,
842 dist,
843 )
844 },
845 )
846 .collect();
847
848 let mut pos_of = vec![0usize; n];
850 for (pos, &u) in order.iter().enumerate() {
851 pos_of[u] = pos;
852 }
853
854 let (incoming_flat, incoming_off) = build_incoming_csr(&order, &new_graph, n);
855
856 graph = (0..n)
857 .into_par_iter()
858 .map(|u| {
859 let ng = &new_graph[pos_of[u]];
860 let inc = &incoming_flat[incoming_off[u]..incoming_off[u + 1]];
861
862 let mut pool_ids: Vec<u32> = Vec::with_capacity(ng.len() + inc.len());
863 pool_ids.extend_from_slice(ng);
864 pool_ids.extend_from_slice(inc);
865 pool_ids.sort_unstable();
866 pool_ids.dedup();
867
868 let pool: Vec<(u32, f32)> = pool_ids
869 .into_iter()
870 .filter(|&id| id as usize != u)
871 .map(|id| (id, dist.eval(vectors.row(u), vectors.row(id as usize))))
872 .collect();
873
874 prune_neighbors(u, &pool, vectors, max_degree, pass_alpha, dist)
875 })
876 .collect();
877 }
878
879 graph
881 .into_par_iter()
882 .enumerate()
883 .map(|(u, neigh)| {
884 if neigh.len() <= max_degree {
885 return neigh;
886 }
887 let pool: Vec<(u32, f32)> = neigh
888 .iter()
889 .map(|&id| (id, dist.eval(vectors.row(u), vectors.row(id as usize))))
890 .collect();
891 prune_neighbors(u, &pool, vectors, max_degree, alpha, dist)
892 })
893 .collect()
894}
895
896fn greedy_search_visited_collect<T, D>(
903 query: &[T],
904 vectors: &FlatVectors<T>,
905 graph: &[Vec<u32>],
906 start_id: usize,
907 beam_width: usize,
908 dist: D,
909 scratch: &mut BuildScratch,
910) where
911 T: bytemuck::Pod + Copy + Send + Sync,
912 D: Distance<T> + Copy,
913{
914 scratch.reset_search();
915
916 let start_dist = dist.eval(query, vectors.row(start_id));
917 let start = Candidate {
918 dist: start_dist,
919 id: start_id as u32,
920 };
921
922 scratch.frontier.insert_unbounded(start);
923 scratch.work.insert_capped(start, beam_width);
924 scratch.mark_with_dist(start_id, start_dist);
925
926 while !scratch.frontier.is_empty() {
927 let best = scratch.frontier.best().unwrap();
928 if scratch.work.len() >= beam_width {
929 if let Some(worst) = scratch.work.worst() {
930 if best.dist >= worst.dist {
931 break;
932 }
933 }
934 }
935
936 let cur = scratch.frontier.pop_best().unwrap();
937
938 for &nb in &graph[cur.id as usize] {
939 let nb_usize = nb as usize;
940 if scratch.is_marked(nb_usize) {
941 continue;
942 }
943
944 let d = dist.eval(query, vectors.row(nb_usize));
945 scratch.mark_with_dist(nb_usize, d);
946
947 let cand = Candidate { dist: d, id: nb };
948
949 if scratch.work.len() < beam_width {
950 scratch.work.insert_unbounded(cand);
951 scratch.frontier.insert_unbounded(cand);
952 } else if let Some(worst) = scratch.work.worst() {
953 if d < worst.dist {
954 scratch.work.insert_capped(cand, beam_width);
955 scratch.frontier.insert_unbounded(cand);
956 }
957 }
958 }
959 }
960}
961
962fn prune_neighbors<T, D>(
964 node_id: usize,
965 candidates: &[(u32, f32)],
966 vectors: &FlatVectors<T>,
967 max_degree: usize,
968 alpha: f32,
969 dist: D,
970) -> Vec<u32>
971where
972 T: bytemuck::Pod + Copy + Send + Sync,
973 D: Distance<T> + Copy,
974{
975 if candidates.is_empty() {
976 return Vec::new();
977 }
978
979 let mut sorted = candidates.to_vec();
980 sorted.sort_by(|a, b| a.1.total_cmp(&b.1));
981
982 let mut pruned = Vec::<u32>::new();
983
984 for &(cand_id, cand_dist) in &sorted {
985 if cand_id as usize == node_id {
986 continue;
987 }
988 let mut ok = true;
989 for &sel in &pruned {
990 let d = dist.eval(vectors.row(cand_id as usize), vectors.row(sel as usize));
991 if alpha * d <= cand_dist {
992 ok = false;
993 break;
994 }
995 }
996 if ok {
997 pruned.push(cand_id);
998 if pruned.len() >= max_degree {
999 break;
1000 }
1001 }
1002 }
1003
1004 for &(cand_id, _) in &sorted {
1005 if cand_id as usize == node_id {
1006 continue;
1007 }
1008 if !pruned.contains(&cand_id) {
1009 pruned.push(cand_id);
1010 if pruned.len() >= max_degree {
1011 break;
1012 }
1013 }
1014 }
1015
1016 pruned
1017}
1018
/// Builds the reverse-edge lists of `new_graph` in compressed (CSR) form.
///
/// `new_graph[pos]` is the out-neighbor list of node `order[pos]`. The function
/// returns `(incoming_flat, off)`, where the sources of all edges pointing at
/// node `v` are `incoming_flat[off[v]..off[v + 1]]`, in the order those edges
/// are encountered while walking `order`.
fn build_incoming_csr(order: &[usize], new_graph: &[Vec<u32>], n: usize) -> (Vec<u32>, Vec<usize>) {
    let rows = &new_graph[..order.len()];

    // Count the in-degree of `v` into slot `v + 1`, then turn the counts into
    // running offsets.
    let mut off = vec![0usize; n + 1];
    for &target in rows.iter().flatten() {
        off[target as usize + 1] += 1;
    }
    for i in 0..n {
        off[i + 1] += off[i];
    }

    // `cursor[v]` is the next free slot inside the range belonging to `v`.
    let mut cursor = off.clone();
    let mut incoming_flat = vec![0u32; off[n]];
    for (&source, targets) in order.iter().zip(rows) {
        for &target in targets {
            let slot = &mut cursor[target as usize];
            incoming_flat[*slot] = source as u32;
            *slot += 1;
        }
    }

    (incoming_flat, off)
}
1043
// Tests that build small indexes on disk and query them.
// Each test writes its own file in the current working directory and removes
// it before and after running.
#[cfg(test)]
mod tests {
    use super::*;
    use anndists::dist::{DistCosine, DistL2};
    use rand::Rng;
    use std::fs;

    /// Plain Euclidean distance, used to check results independently of the index.
    fn euclid(a: &[f32], b: &[f32]) -> f32 {
        a.iter()
            .zip(b)
            .map(|(x, y)| (x - y) * (x - y))
            .sum::<f32>()
            .sqrt()
    }

    /// Five 2-D points: a search for 3 neighbors returns 3 ids and the first
    /// one lies within distance 1.0 of the query.
    #[test]
    fn test_small_index_l2() {
        let path = "test_small_l2.db";
        let _ = fs::remove_file(path);

        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
            vec![0.5, 0.5],
        ];

        let index = DiskANN::<f32, DistL2>::build_index_default(&vectors, DistL2, path).unwrap();

        let q = vec![0.1, 0.1];
        let nns = index.search(&q, 3, 8);
        assert_eq!(nns.len(), 3);

        let v = index.get_vector(nns[0] as usize);
        assert!(euclid(&q, &v) < 1.0);

        let _ = fs::remove_file(path);
    }

    /// Cosine metric: the top hit for a query along the x axis has a cosine
    /// similarity above 0.7 with it.
    #[test]
    fn test_cosine() {
        let path = "test_cosine.db";
        let _ = fs::remove_file(path);

        let vectors = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
            vec![1.0, 1.0, 0.0],
            vec![1.0, 0.0, 1.0],
        ];

        let index =
            DiskANN::<f32, DistCosine>::build_index_default(&vectors, DistCosine, path).unwrap();

        let q = vec![2.0, 0.0, 0.0];
        let nns = index.search(&q, 2, 8);
        assert_eq!(nns.len(), 2);

        // Recompute the cosine similarity by hand for the best hit.
        let v = index.get_vector(nns[0] as usize);
        let dot = v.iter().zip(&q).map(|(a, b)| a * b).sum::<f32>();
        let n1 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        let n2 = q.iter().map(|x| x * x).sum::<f32>().sqrt();
        let cos = dot / (n1 * n2);
        assert!(cos > 0.7);

        let _ = fs::remove_file(path);
    }

    /// An index that was built and dropped can be reopened from disk and
    /// still answers queries.
    #[test]
    fn test_persistence_and_open() {
        let path = "test_persist.db";
        let _ = fs::remove_file(path);

        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
        ];

        // Build inside a scope so the first handle is dropped before reopening.
        {
            let _idx =
                DiskANN::<f32, DistL2>::build_index_default(&vectors, DistL2, path).unwrap();
        }

        let idx2 = DiskANN::<f32, DistL2>::open_index_default_metric(path).unwrap();
        assert_eq!(idx2.num_vectors, 4);
        assert_eq!(idx2.dim, 2);

        // (0.9, 0.9) is closest to vector 3 = (1.0, 1.0).
        let q = vec![0.9, 0.9];
        let res = idx2.search(&q, 2, 8);
        assert_eq!(res[0], 3);

        let _ = fs::remove_file(path);
    }

    /// 5x5 integer grid with a small degree: every grid point, used as a
    /// query, gets results that are close to it.
    #[test]
    fn test_grid_connectivity() {
        let path = "test_grid.db";
        let _ = fs::remove_file(path);

        let mut vectors = Vec::new();
        for i in 0..5 {
            for j in 0..5 {
                vectors.push(vec![i as f32, j as f32]);
            }
        }

        let index = DiskANN::<f32, DistL2>::build_index_with_params(
            &vectors,
            DistL2,
            path,
            DiskAnnParams {
                max_degree: 4,
                build_beam_width: 64,
                alpha: 1.5,
                passes: DISKANN_DEFAULT_PASSES,
                extra_seeds: DISKANN_DEFAULT_EXTRA_SEEDS,
            },
        )
        .unwrap();

        for target in 0..vectors.len() {
            let q = &vectors[target];
            let nns = index.search(q, 10, 32);
            // The exact point may be missed; then the best hit must still be near.
            if !nns.contains(&(target as u32)) {
                let v = index.get_vector(nns[0] as usize);
                assert!(euclid(q, &v) < 2.0);
            }
            for &nb in nns.iter().take(5) {
                let v = index.get_vector(nb as usize);
                assert!(euclid(q, &v) < 5.0);
            }
        }

        let _ = fs::remove_file(path);
    }

    /// 200 random 32-D vectors: a search returns exactly k results, ordered
    /// by non-decreasing distance to the query.
    #[test]
    fn test_medium_random() {
        let path = "test_medium.db";
        let _ = fs::remove_file(path);

        let n = 200usize;
        let d = 32usize;
        let mut rng = rand::thread_rng();
        let vectors: Vec<Vec<f32>> = (0..n)
            .map(|_| (0..d).map(|_| rng.r#gen::<f32>()).collect())
            .collect();

        let index = DiskANN::<f32, DistL2>::build_index_with_params(
            &vectors,
            DistL2,
            path,
            DiskAnnParams {
                max_degree: 32,
                build_beam_width: 128,
                alpha: 1.2,
                passes: DISKANN_DEFAULT_PASSES,
                extra_seeds: DISKANN_DEFAULT_EXTRA_SEEDS,
            },
        )
        .unwrap();

        let q: Vec<f32> = (0..d).map(|_| rng.r#gen::<f32>()).collect();
        let res = index.search(&q, 10, 64);
        assert_eq!(res.len(), 10);

        // Distances recomputed from the stored vectors must already be sorted.
        let dists: Vec<f32> = res
            .iter()
            .map(|&id| {
                let v = index.get_vector(id as usize);
                euclid(&q, &v)
            })
            .collect();
        let mut sorted = dists.clone();
        sorted.sort_by(|a, b| a.total_cmp(b));
        assert_eq!(dists, sorted);

        let _ = fs::remove_file(path);
    }
}