1use anndists::prelude::Distance;
41use bytemuck;
42use memmap2::Mmap;
43use rand::prelude::*;
44use rayon::prelude::*;
45use serde::{Deserialize, Serialize};
46use std::cmp::{Ordering, Reverse};
47use std::collections::{BinaryHeap, HashSet};
48use std::fs::OpenOptions;
49use std::io::{Read, Seek, SeekFrom, Write};
50use thiserror::Error;
51
/// Sentinel stored in unused adjacency slots of the on-disk graph; never a valid node id.
const PAD_U32: u32 = u32::MAX;

/// Default maximum number of out-neighbors stored per node.
pub const DISKANN_DEFAULT_MAX_DEGREE: usize = 64;
/// Default beam width of the candidate searches run during graph construction.
pub const DISKANN_DEFAULT_BUILD_BEAM: usize = 128;
/// Default `alpha` pruning factor passed to the graph builder.
pub const DISKANN_DEFAULT_ALPHA: f32 = 1.2;

/// Construction parameters for building a DiskANN index.
///
/// `Default` yields the `DISKANN_DEFAULT_*` constants.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct DiskAnnParams {
    /// Maximum number of out-neighbors stored per node.
    pub max_degree: usize,
    /// Beam width of the greedy searches used to collect neighbor candidates.
    pub build_beam_width: usize,
    /// Pruning factor used when selecting each node's neighbors.
    pub alpha: f32,
}

impl Default for DiskAnnParams {
    /// Returns the crate-wide default construction parameters.
    fn default() -> Self {
        Self {
            max_degree: DISKANN_DEFAULT_MAX_DEGREE,
            build_beam_width: DISKANN_DEFAULT_BUILD_BEAM,
            alpha: DISKANN_DEFAULT_ALPHA,
        }
    }
}
76
/// Errors returned when building or opening a DiskANN index.
#[derive(Debug, Error)]
pub enum DiskAnnError {
    /// Underlying file I/O failure (open, seek, read, write, sync, mmap).
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),

    /// Failure to encode or decode the bincode metadata header.
    #[error("Serialization error: {0}")]
    Bincode(#[from] bincode::Error),

    /// Invalid input or index contents, described by the message.
    #[error("Index error: {0}")]
    IndexError(String),
}
92
/// Index header, bincode-encoded at the start of the file right after an
/// 8-byte little-endian length prefix.
#[derive(Serialize, Deserialize, Debug)]
struct Metadata {
    // Number of f32 components per vector.
    dim: usize,
    // Number of stored vectors (= graph nodes).
    num_vectors: usize,
    // Number of u32 adjacency slots per node in the file.
    max_degree: usize,
    // Node every search starts from.
    medoid_id: u32,
    // Byte offset of the vector region.
    vectors_offset: u64,
    // Byte offset of the adjacency region.
    adjacency_offset: u64,
    // `std::any::type_name` of the distance type used at build time.
    distance_name: String,
}
104
/// A node id paired with its distance to the current query, as kept in the
/// search heaps.
///
/// Ordering is by distance using [`f32::total_cmp`], so NaN distances get a
/// deterministic position instead of comparing "equal" to everything, with ties
/// broken by id. `PartialEq`/`Eq` are defined through the same ordering so that
/// `Eq` and `Ord` agree, as `BinaryHeap` requires.
#[derive(Clone, Copy, Debug)]
struct Candidate {
    dist: f32,
    id: u32,
}

impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Candidate {}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> Ordering {
        self.dist
            .total_cmp(&other.dist)
            .then_with(|| self.id.cmp(&other.id))
    }
}
128
/// A DiskANN (Vamana graph) index whose vectors and adjacency lists are read
/// directly from a memory-mapped index file.
///
/// `D` is the distance used for search; its type name is recorded in the file
/// as `distance_name` when the index is built.
pub struct DiskANN<D>
where
    D: Distance<f32> + Send + Sync + Copy + Clone + 'static,
{
    /// Number of f32 components in every stored vector.
    pub dim: usize,
    /// Number of stored vectors (and graph nodes).
    pub num_vectors: usize,
    /// Number of u32 adjacency slots per node in the file.
    pub max_degree: usize,
    /// `std::any::type_name` of the distance type the index was built with.
    pub distance_name: String,

    /// Node every search starts from.
    medoid_id: u32,
    /// Byte offset of the vector region within the mmap.
    vectors_offset: u64,
    /// Byte offset of the adjacency region within the mmap.
    adjacency_offset: u64,

    /// Read-only memory map of the whole index file.
    mmap: Mmap,

    /// Distance function used by searches on this handle.
    dist: D,
}
155
156impl<D> DiskANN<D>
159where
160 D: Distance<f32> + Send + Sync + Copy + Clone + 'static,
161{
162 pub fn build_index_default(
164 vectors: &[Vec<f32>],
165 dist: D,
166 file_path: &str,
167 ) -> Result<Self, DiskAnnError> {
168 Self::build_index(
169 vectors,
170 DISKANN_DEFAULT_MAX_DEGREE,
171 DISKANN_DEFAULT_BUILD_BEAM,
172 DISKANN_DEFAULT_ALPHA,
173 dist,
174 file_path,
175 )
176 }
177
178 pub fn build_index_with_params(
180 vectors: &[Vec<f32>],
181 dist: D,
182 file_path: &str,
183 p: DiskAnnParams,
184 ) -> Result<Self, DiskAnnError> {
185 Self::build_index(
186 vectors,
187 p.max_degree,
188 p.build_beam_width,
189 p.alpha,
190 dist,
191 file_path,
192 )
193 }
194}
195
196impl<D> DiskANN<D>
198where
199 D: Distance<f32> + Default + Send + Sync + Copy + Clone + 'static,
200{
201 pub fn build_index_default_metric(
203 vectors: &[Vec<f32>],
204 file_path: &str,
205 ) -> Result<Self, DiskAnnError> {
206 Self::build_index_default(vectors, D::default(), file_path)
207 }
208
209 pub fn open_index_default_metric(path: &str) -> Result<Self, DiskAnnError> {
211 Self::open_index_with(path, D::default())
212 }
213}
214
215impl<D> DiskANN<D>
216where
217 D: Distance<f32> + Send + Sync + Copy + Clone + 'static,
218{
219 pub fn build_index(
229 vectors: &[Vec<f32>],
230 max_degree: usize,
231 build_beam_width: usize,
232 alpha: f32,
233 dist: D,
234 file_path: &str,
235 ) -> Result<Self, DiskAnnError> {
236 if vectors.is_empty() {
237 return Err(DiskAnnError::IndexError("No vectors provided".to_string()));
238 }
239
240 let num_vectors = vectors.len();
241 let dim = vectors[0].len();
242 for (i, v) in vectors.iter().enumerate() {
243 if v.len() != dim {
244 return Err(DiskAnnError::IndexError(format!(
245 "Vector {} has dimension {} but expected {}",
246 i,
247 v.len(),
248 dim
249 )));
250 }
251 }
252
253 let mut file = OpenOptions::new()
254 .create(true)
255 .write(true)
256 .read(true)
257 .truncate(true)
258 .open(file_path)?;
259
260 let vectors_offset = 1024 * 1024;
262 let total_vector_bytes = (num_vectors as u64) * (dim as u64) * 4;
263
264 file.seek(SeekFrom::Start(vectors_offset))?;
266 for vector in vectors {
267 let bytes = bytemuck::cast_slice(vector);
268 file.write_all(bytes)?;
269 }
270
271 let medoid_id = calculate_medoid(vectors, dist);
273
274 let adjacency_offset = vectors_offset as u64 + total_vector_bytes;
276 let graph = build_vamana_graph(
277 vectors,
278 max_degree,
279 build_beam_width,
280 alpha,
281 dist,
282 medoid_id as u32,
283 );
284
285 file.seek(SeekFrom::Start(adjacency_offset))?;
287 for neighbors in &graph {
288 let mut padded = neighbors.clone();
289 padded.resize(max_degree, PAD_U32);
290 let bytes = bytemuck::cast_slice(&padded);
291 file.write_all(bytes)?;
292 }
293
294 let metadata = Metadata {
296 dim,
297 num_vectors,
298 max_degree,
299 medoid_id: medoid_id as u32,
300 vectors_offset: vectors_offset as u64,
301 adjacency_offset,
302 distance_name: std::any::type_name::<D>().to_string(),
303 };
304
305 let md_bytes = bincode::serialize(&metadata)?;
306 file.seek(SeekFrom::Start(0))?;
307 let md_len = md_bytes.len() as u64;
308 file.write_all(&md_len.to_le_bytes())?;
309 file.write_all(&md_bytes)?;
310 file.sync_all()?;
311
312 let mmap = unsafe { memmap2::Mmap::map(&file)? };
314
315 Ok(Self {
316 dim,
317 num_vectors,
318 max_degree,
319 distance_name: metadata.distance_name,
320 medoid_id: metadata.medoid_id,
321 vectors_offset: metadata.vectors_offset,
322 adjacency_offset: metadata.adjacency_offset,
323 mmap,
324 dist,
325 })
326 }
327
328 pub fn open_index_with(path: &str, dist: D) -> Result<Self, DiskAnnError> {
330 let mut file = OpenOptions::new().read(true).write(false).open(path)?;
331
332 let mut buf8 = [0u8; 8];
334 file.seek(SeekFrom::Start(0))?;
335 file.read_exact(&mut buf8)?;
336 let md_len = u64::from_le_bytes(buf8);
337
338 let mut md_bytes = vec![0u8; md_len as usize];
340 file.read_exact(&mut md_bytes)?;
341 let metadata: Metadata = bincode::deserialize(&md_bytes)?;
342
343 let mmap = unsafe { memmap2::Mmap::map(&file)? };
344
345 let expected = std::any::type_name::<D>();
347 if metadata.distance_name != expected {
348 eprintln!(
349 "Warning: index recorded distance `{}` but you opened with `{}`",
350 metadata.distance_name, expected
351 );
352 }
353
354 Ok(Self {
355 dim: metadata.dim,
356 num_vectors: metadata.num_vectors,
357 max_degree: metadata.max_degree,
358 distance_name: metadata.distance_name,
359 medoid_id: metadata.medoid_id,
360 vectors_offset: metadata.vectors_offset,
361 adjacency_offset: metadata.adjacency_offset,
362 mmap,
363 dist,
364 })
365 }
366
367 pub fn search_with_dists(&self, query: &[f32], k: usize, beam_width: usize) -> Vec<(u32, f32)> {
372 assert_eq!(
373 query.len(),
374 self.dim,
375 "Query dim {} != index dim {}",
376 query.len(),
377 self.dim
378 );
379
380 #[derive(Clone, Copy)]
381 struct Candidate {
382 dist: f32,
383 id: u32,
384 }
385 impl PartialEq for Candidate {
386 fn eq(&self, o: &Self) -> bool {
387 self.dist == o.dist && self.id == o.id
388 }
389 }
390 impl Eq for Candidate {}
391 impl PartialOrd for Candidate {
392 fn partial_cmp(&self, o: &Self) -> Option<Ordering> {
393 self.dist.partial_cmp(&o.dist)
394 }
395 }
396 impl Ord for Candidate {
397 fn cmp(&self, o: &Self) -> Ordering {
398 self.partial_cmp(o).unwrap_or(Ordering::Equal)
399 }
400 }
401
402 let mut visited = HashSet::new();
403 let mut frontier: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new(); let mut w: BinaryHeap<Candidate> = BinaryHeap::new(); let start_dist = self.distance_to(query, self.medoid_id as usize);
408 let start = Candidate {
409 dist: start_dist,
410 id: self.medoid_id,
411 };
412 frontier.push(Reverse(start));
413 w.push(start);
414 visited.insert(self.medoid_id);
415
416 while let Some(Reverse(best)) = frontier.peek().copied() {
418 if w.len() >= beam_width {
419 if let Some(worst) = w.peek() {
420 if best.dist >= worst.dist {
421 break;
422 }
423 }
424 }
425 let Reverse(current) = frontier.pop().unwrap();
426
427 for &nb in self.get_neighbors(current.id) {
428 if nb == PAD_U32 {
429 continue;
430 }
431 if !visited.insert(nb) {
432 continue;
433 }
434
435 let d = self.distance_to(query, nb as usize);
436 let cand = Candidate { dist: d, id: nb };
437
438 if w.len() < beam_width {
439 w.push(cand);
440 frontier.push(Reverse(cand));
441 } else if d < w.peek().unwrap().dist {
442 w.pop();
443 w.push(cand);
444 frontier.push(Reverse(cand));
445 }
446 }
447 }
448
449 let mut results: Vec<_> = w.into_vec();
451 results.sort_by(|a, b| a.dist.partial_cmp(&b.dist).unwrap());
452 results.truncate(k);
453 results.into_iter().map(|c| (c.id, c.dist)).collect()
454 }
455 pub fn search(&self, query: &[f32], k: usize, beam_width: usize) -> Vec<u32> {
457 self.search_with_dists(query, k, beam_width)
458 .into_iter()
459 .map(|(id, _dist)| id)
460 .collect()
461 }
462
463 fn get_neighbors(&self, node_id: u32) -> &[u32] {
465 let offset = self.adjacency_offset + (node_id as u64 * self.max_degree as u64 * 4);
466 let start = offset as usize;
467 let end = start + (self.max_degree * 4);
468 let bytes = &self.mmap[start..end];
469 bytemuck::cast_slice(bytes)
470 }
471
472 fn distance_to(&self, query: &[f32], idx: usize) -> f32 {
474 let offset = self.vectors_offset + (idx as u64 * self.dim as u64 * 4);
475 let start = offset as usize;
476 let end = start + (self.dim * 4);
477 let bytes = &self.mmap[start..end];
478 let vector: &[f32] = bytemuck::cast_slice(bytes);
479 self.dist.eval(query, vector)
480 }
481
482 pub fn get_vector(&self, idx: usize) -> Vec<f32> {
484 let offset = self.vectors_offset + (idx as u64 * self.dim as u64 * 4);
485 let start = offset as usize;
486 let end = start + (self.dim * 4);
487 let bytes = &self.mmap[start..end];
488 let vector: &[f32] = bytemuck::cast_slice(bytes);
489 vector.to_vec()
490 }
491}
492
493fn calculate_medoid<D: Distance<f32> + Copy + Sync>(vectors: &[Vec<f32>], dist: D) -> usize {
496 let dim = vectors[0].len();
497 let mut centroid = vec![0.0f32; dim];
498
499 for v in vectors {
500 for (i, &val) in v.iter().enumerate() {
501 centroid[i] += val;
502 }
503 }
504 for val in &mut centroid {
505 *val /= vectors.len() as f32;
506 }
507
508 let (best_idx, _best_dist) = vectors
509 .par_iter()
510 .enumerate()
511 .map(|(idx, v)| (idx, dist.eval(¢roid, v)))
512 .reduce(|| (0usize, f32::MAX), |a, b| if a.1 <= b.1 { a } else { b });
513
514 best_idx
515}
516
/// Builds the in-memory Vamana neighbor graph for `vectors`.
///
/// Starts from a random graph, then runs `PASSES` refinement passes. In each pass
/// every node's new out-list is computed in parallel from a read-only snapshot of
/// the previous graph. The candidates are its current neighbors plus the results
/// of `greedy_search` from the medoid and `EXTRA_SEEDS` random starts, reduced by
/// `prune_neighbors`. Each node's final list for the pass is then re-pruned from
/// the union of its new out-list and all nodes whose new lists point to it.
///
/// Returns one neighbor list per node, indexed by node id.
fn build_vamana_graph<D: Distance<f32> + Copy + Sync>(
    vectors: &[Vec<f32>],
    max_degree: usize,
    build_beam_width: usize,
    alpha: f32,
    dist: D,
    medoid_id: u32,
) -> Vec<Vec<u32>> {
    let n = vectors.len();
    let mut graph = vec![Vec::<u32>::new(); n];

    // Random initial graph: each node gets `target` distinct random neighbors
    // other than itself (capped at n - 1 so the loop can terminate).
    {
        let mut rng = thread_rng();
        for i in 0..n {
            let mut s = HashSet::new();
            let target = (max_degree / 2).max(2).min(n.saturating_sub(1));
            while s.len() < target {
                let nb = rng.gen_range(0..n);
                if nb != i {
                    s.insert(nb as u32);
                }
            }
            graph[i] = s.into_iter().collect();
        }
    }

    const PASSES: usize = 2;
    // Random search starts used in addition to the medoid.
    const EXTRA_SEEDS: usize = 2;

    let mut rng = thread_rng();
    for _pass in 0..PASSES {
        let mut order: Vec<usize> = (0..n).collect();
        order.shuffle(&mut rng);

        // All nodes read the previous graph; results do not feed back within a pass.
        let snapshot = &graph;

        // new_graph[pos] is the pruned out-list of node order[pos].
        let new_graph: Vec<Vec<u32>> = order
            .par_iter()
            .map(|&u| {
                let mut candidates: Vec<(u32, f32)> =
                    Vec::with_capacity(build_beam_width * (2 + EXTRA_SEEDS));

                // Current neighbors are always candidates.
                for &nb in &snapshot[u] {
                    let d = dist.eval(&vectors[u], &vectors[nb as usize]);
                    candidates.push((nb, d));
                }

                let mut seeds = Vec::with_capacity(1 + EXTRA_SEEDS);
                seeds.push(medoid_id as usize);
                let mut trng = thread_rng();
                for _ in 0..EXTRA_SEEDS {
                    seeds.push(trng.gen_range(0..n));
                }

                // Collect the beam-search results from every seed.
                for start in seeds {
                    let mut part = greedy_search(
                        &vectors[u],
                        vectors,
                        snapshot,
                        start,
                        build_beam_width,
                        dist,
                    );
                    candidates.append(&mut part);
                }

                // Deduplicate by id: `dedup_by` removes the later element `a` when it
                // matches the retained `b`, so first copy the smaller distance into `b`.
                candidates.sort_by(|a, b| a.0.cmp(&b.0));
                candidates.dedup_by(|a, b| {
                    if a.0 == b.0 {
                        if a.1 < b.1 {
                            *b = *a;
                        }
                        true
                    } else {
                        false
                    }
                });

                prune_neighbors(u, &candidates, vectors, max_degree, alpha, dist)
            })
            .collect();

        // Inverse of `order`: node id -> its position in new_graph.
        let mut pos_of = vec![0usize; n];
        for (pos, &u) in order.iter().enumerate() {
            pos_of[u] = pos;
        }

        // For every node, the list of nodes whose new out-lists contain it.
        let (incoming_flat, incoming_off) = build_incoming_csr(&order, &new_graph, n);

        // Re-prune each node from its new out-list plus its incoming edges.
        graph = (0..n)
            .into_par_iter()
            .map(|u| {
                let ng = &new_graph[pos_of[u]];
                let inc = &incoming_flat[incoming_off[u]..incoming_off[u + 1]];
                let mut pool_ids: Vec<u32> = Vec::with_capacity(ng.len() + inc.len());
                pool_ids.extend_from_slice(ng);
                pool_ids.extend_from_slice(inc);
                pool_ids.sort_unstable();
                pool_ids.dedup();

                let pool: Vec<(u32, f32)> = pool_ids
                    .into_iter()
                    .filter(|&id| id as usize != u)
                    .map(|id| (id, dist.eval(&vectors[u], &vectors[id as usize])))
                    .collect();

                prune_neighbors(u, &pool, vectors, max_degree, alpha, dist)
            })
            .collect();
    }

    // Final safeguard: re-prune any list that is still longer than max_degree.
    graph
        .into_par_iter()
        .enumerate()
        .map(|(u, neigh)| {
            if neigh.len() <= max_degree {
                return neigh;
            }
            let pool: Vec<(u32, f32)> = neigh
                .iter()
                .map(|&id| (id, dist.eval(&vectors[u], &vectors[id as usize])))
                .collect();
            prune_neighbors(u, &pool, vectors, max_degree, alpha, dist)
        })
        .collect()
}
667
668fn greedy_search<D: Distance<f32> + Copy>(
671 query: &[f32],
672 vectors: &[Vec<f32>],
673 graph: &[Vec<u32>],
674 start_id: usize,
675 beam_width: usize,
676 dist: D,
677) -> Vec<(u32, f32)> {
678 let mut visited = HashSet::new();
679 let mut frontier: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new(); let mut w: BinaryHeap<Candidate> = BinaryHeap::new(); let start_dist = dist.eval(query, &vectors[start_id]);
683 let start = Candidate {
684 dist: start_dist,
685 id: start_id as u32,
686 };
687 frontier.push(Reverse(start));
688 w.push(start);
689 visited.insert(start_id as u32);
690
691 while let Some(Reverse(best)) = frontier.peek().copied() {
692 if w.len() >= beam_width {
693 if let Some(worst) = w.peek() {
694 if best.dist >= worst.dist {
695 break;
696 }
697 }
698 }
699 let Reverse(cur) = frontier.pop().unwrap();
700
701 for &nb in &graph[cur.id as usize] {
702 if !visited.insert(nb) {
703 continue;
704 }
705 let d = dist.eval(query, &vectors[nb as usize]);
706 let cand = Candidate { dist: d, id: nb };
707
708 if w.len() < beam_width {
709 w.push(cand);
710 frontier.push(Reverse(cand));
711 } else if d < w.peek().unwrap().dist {
712 w.pop();
713 w.push(cand);
714 frontier.push(Reverse(cand));
715 }
716 }
717 }
718
719 let mut v = w.into_vec();
720 v.sort_by(|a, b| a.dist.partial_cmp(&b.dist).unwrap());
721 v.into_iter().map(|c| (c.id, c.dist)).collect()
722}
723
724fn prune_neighbors<D: Distance<f32> + Copy>(
726 node_id: usize,
727 candidates: &[(u32, f32)],
728 vectors: &[Vec<f32>],
729 max_degree: usize,
730 alpha: f32,
731 dist: D,
732) -> Vec<u32> {
733 if candidates.is_empty() {
734 return Vec::new();
735 }
736
737 let mut sorted = candidates.to_vec();
738 sorted.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
739
740 let mut pruned = Vec::<u32>::new();
741
742 for &(cand_id, cand_dist) in &sorted {
743 if cand_id as usize == node_id {
744 continue;
745 }
746 let mut ok = true;
747 for &sel in &pruned {
748 let d = dist.eval(&vectors[cand_id as usize], &vectors[sel as usize]);
749 if d < alpha * cand_dist {
750 ok = false;
751 break;
752 }
753 }
754 if ok {
755 pruned.push(cand_id);
756 if pruned.len() >= max_degree {
757 break;
758 }
759 }
760 }
761
762 for &(cand_id, _) in &sorted {
764 if cand_id as usize == node_id {
765 continue;
766 }
767 if !pruned.contains(&cand_id) {
768 pruned.push(cand_id);
769 if pruned.len() >= max_degree {
770 break;
771 }
772 }
773 }
774
775 pruned
776}
777
/// Builds the reverse adjacency of `new_graph` in CSR form.
///
/// `new_graph[pos]` is the out-list of node `order[pos]`. Returns
/// `(incoming_flat, off)`, where `incoming_flat[off[v]..off[v + 1]]` lists every
/// source node whose out-list contains `v`, in increasing `pos` order.
/// `off` has `n + 1` entries.
fn build_incoming_csr(order: &[usize], new_graph: &[Vec<u32>], n: usize) -> (Vec<u32>, Vec<usize>) {
    let lists = &new_graph[..order.len()];

    // Count in-degrees shifted by one slot, then take an inclusive prefix sum so
    // that off[v] is where v's incoming run begins.
    let mut off = vec![0usize; n + 1];
    for &target in lists.iter().flatten() {
        off[target as usize + 1] += 1;
    }
    let mut running = 0usize;
    for slot in off.iter_mut() {
        running += *slot;
        *slot = running;
    }

    // Scatter each source into its targets' runs; `cursor[v]` is v's next free slot.
    let mut cursor = off[..n].to_vec();
    let mut incoming_flat = vec![0u32; off[n]];
    for (&src, targets) in order.iter().zip(lists) {
        for &target in targets {
            let slot = &mut cursor[target as usize];
            incoming_flat[*slot] = src as u32;
            *slot += 1;
        }
    }

    (incoming_flat, off)
}
803
#[cfg(test)]
mod tests {
    use super::*;
    use anndists::dist::{DistCosine, DistL2};
    use rand::Rng;
    use std::fs;
    use std::path::PathBuf;

    /// Index file path in the system temp directory that is removed on drop.
    ///
    /// The temp dir keeps test artifacts out of the working directory, and the PID
    /// prefix keeps concurrent test runs apart. `Drop` cleans up even when an
    /// assertion fails mid-test. Declare it before the index so the index (and its
    /// mmap) is dropped first.
    struct TempIndexFile(PathBuf);

    impl TempIndexFile {
        fn new(name: &str) -> Self {
            let path =
                std::env::temp_dir().join(format!("diskann_{}_{}", std::process::id(), name));
            let _ = fs::remove_file(&path);
            Self(path)
        }

        fn path(&self) -> &str {
            self.0.to_str().expect("temp dir path is not valid UTF-8")
        }
    }

    impl Drop for TempIndexFile {
        fn drop(&mut self) {
            let _ = fs::remove_file(&self.0);
        }
    }

    fn euclid(a: &[f32], b: &[f32]) -> f32 {
        a.iter()
            .zip(b)
            .map(|(x, y)| (x - y) * (x - y))
            .sum::<f32>()
            .sqrt()
    }

    #[test]
    fn test_small_index_l2() {
        let file = TempIndexFile::new("test_small_l2.db");

        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
            vec![0.5, 0.5],
        ];

        let index =
            DiskANN::<DistL2>::build_index_default(&vectors, DistL2 {}, file.path()).unwrap();

        let q = vec![0.1, 0.1];
        let nns = index.search(&q, 3, 8);
        assert_eq!(nns.len(), 3);

        let v = index.get_vector(nns[0] as usize);
        assert!(euclid(&q, &v) < 1.0);
    }

    #[test]
    fn test_cosine() {
        let file = TempIndexFile::new("test_cosine.db");

        let vectors = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
            vec![1.0, 1.0, 0.0],
            vec![1.0, 0.0, 1.0],
        ];

        let index =
            DiskANN::<DistCosine>::build_index_default(&vectors, DistCosine {}, file.path())
                .unwrap();

        let q = vec![2.0, 0.0, 0.0];
        let nns = index.search(&q, 2, 8);
        assert_eq!(nns.len(), 2);

        let v = index.get_vector(nns[0] as usize);
        let dot = v.iter().zip(&q).map(|(a, b)| a * b).sum::<f32>();
        let n1 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        let n2 = q.iter().map(|x| x * x).sum::<f32>().sqrt();
        let cos = dot / (n1 * n2);
        assert!(cos > 0.7);
    }

    #[test]
    fn test_persistence_and_open() {
        let file = TempIndexFile::new("test_persist.db");

        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
        ];

        {
            let _idx =
                DiskANN::<DistL2>::build_index_default(&vectors, DistL2 {}, file.path()).unwrap();
        }

        let idx2 = DiskANN::<DistL2>::open_index_default_metric(file.path()).unwrap();
        assert_eq!(idx2.num_vectors, 4);
        assert_eq!(idx2.dim, 2);

        let q = vec![0.9, 0.9];
        let res = idx2.search(&q, 2, 8);
        assert_eq!(res[0], 3);
    }

    #[test]
    fn test_grid_connectivity() {
        let file = TempIndexFile::new("test_grid.db");

        let mut vectors = Vec::new();
        for i in 0..5 {
            for j in 0..5 {
                vectors.push(vec![i as f32, j as f32]);
            }
        }

        let index = DiskANN::<DistL2>::build_index_with_params(
            &vectors,
            DistL2 {},
            file.path(),
            DiskAnnParams {
                max_degree: 4,
                build_beam_width: 64,
                alpha: 1.5,
            },
        )
        .unwrap();

        for target in 0..vectors.len() {
            let q = &vectors[target];
            let nns = index.search(q, 10, 32);
            if !nns.contains(&(target as u32)) {
                let v = index.get_vector(nns[0] as usize);
                assert!(euclid(q, &v) < 2.0);
            }
            for &nb in nns.iter().take(5) {
                let v = index.get_vector(nb as usize);
                assert!(euclid(q, &v) < 5.0);
            }
        }
    }

    #[test]
    fn test_medium_random() {
        let file = TempIndexFile::new("test_medium.db");

        let n = 200usize;
        let d = 32usize;
        let mut rng = rand::thread_rng();
        let vectors: Vec<Vec<f32>> = (0..n)
            .map(|_| (0..d).map(|_| rng.r#gen::<f32>()).collect())
            .collect();

        let index = DiskANN::<DistL2>::build_index_with_params(
            &vectors,
            DistL2 {},
            file.path(),
            DiskAnnParams {
                max_degree: 32,
                build_beam_width: 128,
                alpha: 1.2,
            },
        )
        .unwrap();

        let q: Vec<f32> = (0..d).map(|_| rng.r#gen::<f32>()).collect();
        let res = index.search(&q, 10, 64);
        assert_eq!(res.len(), 10);

        let dists: Vec<f32> = res
            .iter()
            .map(|&id| {
                let v = index.get_vector(id as usize);
                euclid(&q, &v)
            })
            .collect();
        let mut sorted = dists.clone();
        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
        assert_eq!(dists, sorted);
    }

    #[test]
    fn test_rejects_empty_and_ragged_input() {
        let file = TempIndexFile::new("test_invalid.db");

        let empty: Vec<Vec<f32>> = Vec::new();
        assert!(DiskANN::<DistL2>::build_index_default(&empty, DistL2 {}, file.path()).is_err());

        let ragged = vec![vec![0.0, 1.0], vec![1.0]];
        assert!(DiskANN::<DistL2>::build_index_default(&ragged, DistL2 {}, file.path()).is_err());
    }
}