Skip to main content

rust_diskann/
lib.rs

1//! # DiskAnn (generic over `anndists::Distance<T>`)
2//!
3//! An on-disk DiskANN library that:
4//! - Builds a Vamana-style graph (greedy + α-pruning) in memory
5//! - Writes vectors + fixed-degree adjacency to a single file
6//! - Memory-maps the file for low-overhead reads
7//! - Is **generic over any Distance<T>** from `anndists` (e.g. L2 on `f32`, Cosine on `f32`,
8//!   Hamming on `u64`, …)
9//!
10//! ## Example (f32 + L2)
11//! ```no_run
12//! use anndists::dist::DistL2;
13//! use rust_diskann::{DiskANN, DiskAnnParams};
14//!
15//! let vectors: Vec<Vec<f32>> = vec![vec![0.0; 128]; 1000];
16//! let index = DiskANN::<f32, DistL2>::build_index_default(&vectors, DistL2, "index.db").unwrap();
17//!
18//! let q = vec![0.0; 128];
19//! let nns = index.search(&q, 10, 64);
20//! ```
21//!
22//! ## Example (u64 + Hamming)
23//! ```no_run
24//! use anndists::dist::DistHamming;
25//! use rust_diskann::{DiskANN, DiskAnnParams};
//! let vectors: Vec<Vec<u64>> = vec![vec![0u64; 128]; 1000];
//! let idx = DiskANN::<u64, DistHamming>::build_index_default(&vectors, DistHamming, "mh.db").unwrap();
28//! let q = vec![0u64; 128];
29//! let _ = idx.search(&q, 10, 64);
30//! ```
31//!
32//! ## File Layout
33//! [ metadata_len:u64 ][ metadata (bincode) ][ padding up to vectors_offset ]
34//! [ vectors (num * dim * T) ][ adjacency (num * max_degree * u32) ]
35//!
36//! `vectors_offset` is a fixed 1 MiB gap by default.
37
38use anndists::prelude::Distance;
39use memmap2::Mmap;
40use rand::{prelude::*, thread_rng};
41use rayon::prelude::*;
42use serde::{Deserialize, Serialize};
43use std::cmp::{Ordering, Reverse};
44use std::collections::{BinaryHeap, HashSet};
45use std::fs::OpenOptions;
46use std::io::{Read, Seek, SeekFrom, Write};
47use std::marker::PhantomData;
48use thiserror::Error;
49
/// Padding sentinel for adjacency slots (avoid colliding with node 0).
const PAD_U32: u32 = u32::MAX;

/// Defaults for in-memory DiskANN builds.
/// Default maximum out-degree per node (R/M in the DiskANN papers).
pub const DISKANN_DEFAULT_MAX_DEGREE: usize = 64;
/// Default beam width (L) used by construction-time greedy searches.
pub const DISKANN_DEFAULT_BUILD_BEAM: usize = 128;
/// Default α for pruning; values > 1.0 retain some longer "detour" edges.
pub const DISKANN_DEFAULT_ALPHA: f32 = 1.2;
/// Default number of refinement passes during graph build
pub const DISKANN_DEFAULT_PASSES: usize = 1;
/// Default number of extra random seeds per node per pass during graph build
pub const DISKANN_DEFAULT_EXTRA_SEEDS: usize = 1;

/// Practical DiskANN-style slack before reverse-neighbor re-pruning.
/// Legacy C++ DiskANN allows reverse lists to grow to about GRAPH_SLACK_FACTOR * R
/// before triggering prune, instead of pruning immediately at R.
const GRAPH_SLACK_FACTOR: f32 = 1.3;
66
/// Optional bag of knobs if you want to override just a few.
#[derive(Clone, Copy, Debug)]
pub struct DiskAnnParams {
    /// Maximum out-degree per node (R/M).
    pub max_degree: usize,
    /// Construction-time beam width (L); larger = better graph, slower build.
    pub build_beam_width: usize,
    /// α-pruning parameter; larger keeps more long-range edges.
    pub alpha: f32,
    /// Number of refinement passes over the graph (>=1).
    pub passes: usize,
    /// Extra random seeds per node during each pass (>=0).
    pub extra_seeds: usize,
}
78
impl Default for DiskAnnParams {
    /// Mirrors the `DISKANN_DEFAULT_*` constants.
    fn default() -> Self {
        Self {
            max_degree: DISKANN_DEFAULT_MAX_DEGREE,
            build_beam_width: DISKANN_DEFAULT_BUILD_BEAM,
            alpha: DISKANN_DEFAULT_ALPHA,
            passes: DISKANN_DEFAULT_PASSES,
            extra_seeds: DISKANN_DEFAULT_EXTRA_SEEDS,
        }
    }
}
90
/// Custom error type for DiskAnn operations
#[derive(Debug, Error)]
pub enum DiskAnnError {
    /// Represents I/O errors during file operations
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),

    /// Represents serialization/deserialization errors
    #[error("Serialization error: {0}")]
    Bincode(#[from] bincode::Error),

    /// Represents index-specific errors (empty input, dimension mismatch,
    /// element-size mismatch when reopening, …)
    #[error("Index error: {0}")]
    IndexError(String),
}
106
/// Internal metadata structure stored in the index file
/// (bincode-serialized right after the leading u64 length word).
#[derive(Serialize, Deserialize, Debug)]
struct Metadata {
    // Dimensionality of every vector.
    dim: usize,
    // Total number of vectors stored.
    num_vectors: usize,
    // Fixed adjacency-slot count per node (unused slots hold PAD_U32).
    max_degree: usize,
    // Entry point for all searches.
    medoid_id: u32,
    // Byte offset of the vector region within the file.
    vectors_offset: u64,
    // Byte offset of the adjacency region within the file.
    adjacency_offset: u64,
    // size_of::<T>() at build time; checked against T when reopening.
    elem_size: u8,
    // type_name::<D>() at build time; informational only (warns on mismatch).
    distance_name: String,
}
119
/// Candidate for search/frontier queues.
///
/// Ordering is total: by distance first (via `f32::total_cmp`, which also
/// orders NaNs deterministically), then by id as a tie-break so two distinct
/// candidates never compare equal.
#[derive(Clone, Copy, Debug)]
struct Candidate {
    dist: f32,
    id: u32,
}

impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        // Bitwise distance comparison keeps Eq consistent with total_cmp.
        self.id == other.id && self.dist.to_bits() == other.dist.to_bits()
    }
}
impl Eq for Candidate {}

impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> Ordering {
        // total_cmp is a total order on f32, so this never needs a fallback.
        self.dist
            .total_cmp(&other.dist)
            .then_with(|| self.id.cmp(&other.id))
    }
}
impl PartialOrd for Candidate {
    // Canonical pattern: PartialOrd delegates to Ord. The previous version
    // implemented the logic in PartialOrd and had Ord call
    // `partial_cmp(..).unwrap_or(Ordering::Equal)`, which clippy flags
    // (non_canonical_ord_impl) and which hides the fact that the order is total.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
146
147/// Flat contiguous matrix used during build to improve cache locality.
148///
149/// Rows are stored consecutively in `data`, row-major.
150#[derive(Clone, Debug)]
151struct FlatVectors<T> {
152    data: Vec<T>,
153    dim: usize,
154    n: usize,
155}
156
157impl<T: Copy> FlatVectors<T> {
158    fn from_vecs(vectors: &[Vec<T>]) -> Result<Self, DiskAnnError> {
159        if vectors.is_empty() {
160            return Err(DiskAnnError::IndexError("No vectors provided".to_string()));
161        }
162        let dim = vectors[0].len();
163        for (i, v) in vectors.iter().enumerate() {
164            if v.len() != dim {
165                return Err(DiskAnnError::IndexError(format!(
166                    "Vector {} has dimension {} but expected {}",
167                    i,
168                    v.len(),
169                    dim
170                )));
171            }
172        }
173
174        let n = vectors.len();
175        let mut data = Vec::with_capacity(n * dim);
176        for v in vectors {
177            data.extend_from_slice(v);
178        }
179
180        Ok(Self { data, dim, n })
181    }
182
183    #[inline]
184    fn row(&self, idx: usize) -> &[T] {
185        let start = idx * self.dim;
186        let end = start + self.dim;
187        &self.data[start..end]
188    }
189}
190
191/// Small ordered beam structure used only during build-time greedy search.
192///
193/// It keeps elements in **descending** distance order:
194/// - index 0 is the worst element
195/// - last element is the best element
196///
197/// This makes:
198/// - `best()` cheap via `last()`
199/// - `worst()` cheap via `first()`
200/// - capped beam maintenance simple
201#[derive(Default, Debug)]
202struct OrderedBeam {
203    items: Vec<Candidate>,
204}
205
206impl OrderedBeam {
207    #[inline]
208    fn clear(&mut self) {
209        self.items.clear();
210    }
211
212    #[inline]
213    fn len(&self) -> usize {
214        self.items.len()
215    }
216
217    #[inline]
218    fn is_empty(&self) -> bool {
219        self.items.is_empty()
220    }
221
222    #[inline]
223    fn best(&self) -> Option<Candidate> {
224        self.items.last().copied()
225    }
226
227    #[inline]
228    fn worst(&self) -> Option<Candidate> {
229        self.items.first().copied()
230    }
231
232    #[inline]
233    fn pop_best(&mut self) -> Option<Candidate> {
234        self.items.pop()
235    }
236
237    #[inline]
238    fn reserve(&mut self, cap: usize) {
239        if self.items.capacity() < cap {
240            self.items.reserve(cap - self.items.capacity());
241        }
242    }
243
244    #[inline]
245    fn insert_unbounded(&mut self, cand: Candidate) {
246        let pos = self.items.partition_point(|x| {
247            x.dist > cand.dist || (x.dist.to_bits() == cand.dist.to_bits() && x.id > cand.id)
248        });
249        self.items.insert(pos, cand);
250    }
251
252    #[inline]
253    fn insert_capped(&mut self, cand: Candidate, cap: usize) {
254        if cap == 0 {
255            return;
256        }
257
258        if self.items.len() < cap {
259            self.insert_unbounded(cand);
260            return;
261        }
262
263        // Since items[0] is the worst, only insert if the new candidate is better.
264        let worst = self.items[0];
265        if cand.dist >= worst.dist {
266            return;
267        }
268
269        self.insert_unbounded(cand);
270
271        if self.items.len() > cap {
272            self.items.remove(0);
273        }
274    }
275}
276
277/// Reusable scratch buffers for build-time greedy search.
278/// One instance is created per Rayon worker via `map_init`, so allocations are reused
279/// across many nodes in the build.
280#[derive(Debug)]
281struct BuildScratch {
282    marks: Vec<u32>,
283    epoch: u32,
284
285    visited_ids: Vec<u32>,
286    visited_dists: Vec<f32>,
287
288    frontier: OrderedBeam,
289    work: OrderedBeam,
290
291    seeds: Vec<usize>,
292    candidates: Vec<(u32, f32)>,
293}
294
295impl BuildScratch {
296    fn new(n: usize, beam_width: usize, max_degree: usize, extra_seeds: usize) -> Self {
297        Self {
298            marks: vec![0u32; n],
299            epoch: 1,
300            visited_ids: Vec::with_capacity(beam_width * 4),
301            visited_dists: Vec::with_capacity(beam_width * 4),
302            frontier: {
303                let mut b = OrderedBeam::default();
304                b.reserve(beam_width * 2);
305                b
306            },
307            work: {
308                let mut b = OrderedBeam::default();
309                b.reserve(beam_width * 2);
310                b
311            },
312            seeds: Vec::with_capacity(1 + extra_seeds),
313            candidates: Vec::with_capacity(beam_width * (4 + extra_seeds) + max_degree * 2),
314        }
315    }
316
317    #[inline]
318    fn reset_search(&mut self) {
319        self.epoch = self.epoch.wrapping_add(1);
320        if self.epoch == 0 {
321            self.marks.fill(0);
322            self.epoch = 1;
323        }
324        self.visited_ids.clear();
325        self.visited_dists.clear();
326        self.frontier.clear();
327        self.work.clear();
328    }
329
330    #[inline]
331    fn is_marked(&self, idx: usize) -> bool {
332        self.marks[idx] == self.epoch
333    }
334
335    #[inline]
336    fn mark_with_dist(&mut self, idx: usize, dist: f32) {
337        self.marks[idx] = self.epoch;
338        self.visited_ids.push(idx as u32);
339        self.visited_dists.push(dist);
340    }
341}
342
/// Main struct representing a DiskANN index (generic over vector element `T` and distance `D`)
pub struct DiskANN<T, D>
where
    T: bytemuck::Pod + Copy + Send + Sync + 'static,
    D: Distance<T> + Send + Sync + Copy + Clone + 'static,
{
    /// Dimensionality of vectors in the index
    pub dim: usize,
    /// Number of vectors in the index
    pub num_vectors: usize,
    /// Maximum number of edges per node
    pub max_degree: usize,
    /// Informational: type name of the distance (from metadata)
    pub distance_name: String,

    /// ID of the medoid (used as entry point)
    medoid_id: u32,
    // Byte offset of the vector region within the mapped file.
    vectors_offset: u64,
    // Byte offset of the fixed-degree adjacency region within the mapped file.
    adjacency_offset: u64,

    /// Memory-mapped file (both vectors and adjacency are read through it)
    mmap: Mmap,

    /// The distance strategy
    dist: D,

    /// keep `T` in the type so the compiler knows about it
    _phantom: PhantomData<T>,
}
373
374// constructors
375
impl<T, D> DiskANN<T, D>
where
    T: bytemuck::Pod + Copy + Send + Sync + 'static,
    D: Distance<T> + Send + Sync + Copy + Clone + 'static,
{
    /// Build with default parameters: (M=64, L=128, alpha=1.2, passes=1, extra_seeds=1).
    /// (Values mirror the `DISKANN_DEFAULT_*` constants.)
    pub fn build_index_default(
        vectors: &[Vec<T>],
        dist: D,
        file_path: &str,
    ) -> Result<Self, DiskAnnError> {
        Self::build_index(
            vectors,
            DISKANN_DEFAULT_MAX_DEGREE,
            DISKANN_DEFAULT_BUILD_BEAM,
            DISKANN_DEFAULT_ALPHA,
            DISKANN_DEFAULT_PASSES,
            DISKANN_DEFAULT_EXTRA_SEEDS,
            dist,
            file_path,
        )
    }

    /// Build with a `DiskAnnParams` bundle.
    pub fn build_index_with_params(
        vectors: &[Vec<T>],
        dist: D,
        file_path: &str,
        p: DiskAnnParams,
    ) -> Result<Self, DiskAnnError> {
        Self::build_index(
            vectors,
            p.max_degree,
            p.build_beam_width,
            p.alpha,
            p.passes,
            p.extra_seeds,
            dist,
            file_path,
        )
    }

    /// Opens an existing index file, supplying the distance strategy explicitly.
    ///
    /// # Errors
    /// Returns `DiskAnnError::Io`/`Bincode` on read/parse failure, and
    /// `IndexError` if the file was built with a different element size than
    /// `size_of::<T>()`. A mismatched distance name only prints a warning.
    pub fn open_index_with(path: &str, dist: D) -> Result<Self, DiskAnnError> {
        let mut file = OpenOptions::new().read(true).write(false).open(path)?;

        // Read metadata length (leading little-endian u64)
        let mut buf8 = [0u8; 8];
        file.seek(SeekFrom::Start(0))?;
        file.read_exact(&mut buf8)?;
        let md_len = u64::from_le_bytes(buf8);

        // Read metadata (bincode blob immediately after the length word)
        let mut md_bytes = vec![0u8; md_len as usize];
        file.read_exact(&mut md_bytes)?;
        let metadata: Metadata = bincode::deserialize(&md_bytes)?;

        // SAFETY-ish: mapping is unsound if the file is mutated externally
        // while mapped; callers are expected not to modify the index file.
        let mmap = unsafe { memmap2::Mmap::map(&file)? };

        // Validate element size vs T
        let want = std::mem::size_of::<T>() as u8;
        if metadata.elem_size != want {
            return Err(DiskAnnError::IndexError(format!(
                "element size mismatch: file has {}B, T is {}B",
                metadata.elem_size, want
            )));
        }

        // Optional sanity/logging: warn if type differs from recorded name
        let expected = std::any::type_name::<D>();
        if metadata.distance_name != expected {
            eprintln!(
                "Warning: index recorded distance `{}` but you opened with `{}`",
                metadata.distance_name, expected
            );
        }

        Ok(Self {
            dim: metadata.dim,
            num_vectors: metadata.num_vectors,
            max_degree: metadata.max_degree,
            distance_name: metadata.distance_name,
            medoid_id: metadata.medoid_id,
            vectors_offset: metadata.vectors_offset,
            adjacency_offset: metadata.adjacency_offset,
            mmap,
            dist,
            _phantom: PhantomData,
        })
    }
}
467
/// Extra sugar when your distance type implements `Default`.
impl<T, D> DiskANN<T, D>
where
    T: bytemuck::Pod + Copy + Send + Sync + 'static,
    D: Distance<T> + Default + Send + Sync + Copy + Clone + 'static,
{
    /// Build with default params **and** `D::default()` metric.
    pub fn build_index_default_metric(
        vectors: &[Vec<T>],
        file_path: &str,
    ) -> Result<Self, DiskAnnError> {
        Self::build_index_default(vectors, D::default(), file_path)
    }

    /// Open an index using `D::default()` as the distance (matches what you built with).
    pub fn open_index_default_metric(path: &str) -> Result<Self, DiskAnnError> {
        Self::open_index_with(path, D::default())
    }
}
487
impl<T, D> DiskANN<T, D>
where
    T: bytemuck::Pod + Copy + Send + Sync + 'static,
    D: Distance<T> + Send + Sync + Copy + Clone + 'static,
{
    /// Builds a new index from provided vectors
    ///
    /// # Arguments
    /// * `vectors` - The vectors to index (slice of Vec<T>)
    /// * `max_degree` - Maximum edges per node (M ~ 24-64+)
    /// * `build_beam_width` - Construction L (e.g., 128-400)
    /// * `alpha` - Pruning parameter (1.2–2.0)
    /// * `passes` - Refinement passes over the graph (>=1)
    /// * `extra_seeds` - Extra random seeds per node per pass (>=0)
    /// * `dist` - Any `anndists::Distance<T>`
    /// * `file_path` - Path of index file
    ///
    /// # Errors
    /// I/O or serialization failures, and `IndexError` for empty/ragged input.
    pub fn build_index(
        vectors: &[Vec<T>],
        max_degree: usize,
        build_beam_width: usize,
        alpha: f32,
        passes: usize,
        extra_seeds: usize,
        dist: D,
        file_path: &str,
    ) -> Result<Self, DiskAnnError> {
        // Validates non-empty, rectangular input and flattens it row-major.
        let flat = FlatVectors::from_vecs(vectors)?;

        let num_vectors = flat.n;
        let dim = flat.dim;

        // truncate(true): rebuild from scratch even if the file already exists.
        let mut file = OpenOptions::new()
            .create(true)
            .write(true)
            .read(true)
            .truncate(true)
            .open(file_path)?;

        // Reserve space for metadata (we'll write it after data).
        // 1 MiB gap; the assert guards the mmap slice-cast alignment below.
        let vectors_offset = 1024 * 1024;
        assert_eq!(
            (vectors_offset as usize) % std::mem::align_of::<T>(),
            0,
            "vectors_offset must be aligned for T"
        );

        let elem_sz = std::mem::size_of::<T>() as u64;
        let total_vector_bytes = (num_vectors as u64) * (dim as u64) * elem_sz;

        // Write vectors contiguous
        file.seek(SeekFrom::Start(vectors_offset as u64))?;
        file.write_all(bytemuck::cast_slice::<T, u8>(&flat.data))?;

        // Compute medoid using flat storage (entry point for all searches)
        let medoid_id = calculate_medoid(&flat, dist);

        // Build graph; adjacency region starts right after the vectors.
        let adjacency_offset = vectors_offset as u64 + total_vector_bytes;
        let graph = build_vamana_graph(
            &flat,
            max_degree,
            build_beam_width,
            alpha,
            passes,
            extra_seeds,
            dist,
            medoid_id as u32,
        );

        // Write adjacency lists, each padded with PAD_U32 to exactly max_degree
        // slots so node i lives at a fixed offset.
        file.seek(SeekFrom::Start(adjacency_offset))?;
        for neighbors in &graph {
            let mut padded = neighbors.clone();
            padded.resize(max_degree, PAD_U32);
            let bytes = bytemuck::cast_slice::<u32, u8>(&padded);
            file.write_all(bytes)?;
        }

        // Write metadata (length-prefixed bincode at offset 0)
        let metadata = Metadata {
            dim,
            num_vectors,
            max_degree,
            medoid_id: medoid_id as u32,
            vectors_offset: vectors_offset as u64,
            adjacency_offset,
            elem_size: std::mem::size_of::<T>() as u8,
            distance_name: std::any::type_name::<D>().to_string(),
        };

        let md_bytes = bincode::serialize(&metadata)?;
        file.seek(SeekFrom::Start(0))?;
        let md_len = md_bytes.len() as u64;
        file.write_all(&md_len.to_le_bytes())?;
        file.write_all(&md_bytes)?;
        file.sync_all()?;

        // Memory map the file. SAFETY-ish: unsound if the file is mutated
        // externally while mapped; callers must leave the index file alone.
        let mmap = unsafe { memmap2::Mmap::map(&file)? };

        Ok(Self {
            dim,
            num_vectors,
            max_degree,
            distance_name: metadata.distance_name,
            medoid_id: metadata.medoid_id,
            vectors_offset: metadata.vectors_offset,
            adjacency_offset: metadata.adjacency_offset,
            mmap,
            dist,
            _phantom: PhantomData,
        })
    }

    /// Searches the index for nearest neighbors using a best-first beam search.
    /// Termination rule: continue while the best frontier can still improve the worst in working set.
    ///
    /// Returns up to `k` `(id, distance)` pairs sorted by ascending distance.
    /// Panics if `query.len() != self.dim`.
    pub fn search_with_dists(&self, query: &[T], k: usize, beam_width: usize) -> Vec<(u32, f32)> {
        assert_eq!(
            query.len(),
            self.dim,
            "Query dim {} != index dim {}",
            query.len(),
            self.dim
        );

        let mut visited = HashSet::new();
        // Min-heap of candidates to expand (Reverse flips the max-heap).
        let mut frontier: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new();
        // Max-heap holding the best `beam_width` results so far; peek() = worst.
        let mut w: BinaryHeap<Candidate> = BinaryHeap::new();

        // Seed from the medoid, the graph's global entry point.
        let start_dist = self.distance_to(query, self.medoid_id as usize);
        let start = Candidate {
            dist: start_dist,
            id: self.medoid_id,
        };
        frontier.push(Reverse(start));
        w.push(start);
        visited.insert(self.medoid_id);

        while let Some(Reverse(best)) = frontier.peek().copied() {
            // Stop once even the best unexpanded candidate cannot improve
            // the current working set.
            if w.len() >= beam_width {
                if let Some(worst) = w.peek() {
                    if best.dist >= worst.dist {
                        break;
                    }
                }
            }
            let Reverse(current) = frontier.pop().unwrap();

            for &nb in self.get_neighbors(current.id) {
                // PAD_U32 marks unused adjacency slots.
                if nb == PAD_U32 {
                    continue;
                }
                if !visited.insert(nb) {
                    continue;
                }

                let d = self.distance_to(query, nb as usize);
                let cand = Candidate { dist: d, id: nb };

                if w.len() < beam_width {
                    w.push(cand);
                    frontier.push(Reverse(cand));
                } else if d < w.peek().unwrap().dist {
                    // Better than the current worst: replace it.
                    w.pop();
                    w.push(cand);
                    frontier.push(Reverse(cand));
                }
            }
        }

        let mut results: Vec<_> = w.into_vec();
        results.sort_by(|a, b| a.dist.total_cmp(&b.dist));
        results.truncate(k);
        results.into_iter().map(|c| (c.id, c.dist)).collect()
    }

    /// search but only return neighbor ids
    pub fn search(&self, query: &[T], k: usize, beam_width: usize) -> Vec<u32> {
        self.search_with_dists(query, k, beam_width)
            .into_iter()
            .map(|(id, _dist)| id)
            .collect()
    }

    /// Gets the neighbors of a node from the (fixed-degree) adjacency region.
    ///
    /// NOTE(review): `cast_slice` panics unless the slice is 4-byte aligned;
    /// this holds when `vectors_offset + num * dim * size_of::<T>()` is a
    /// multiple of 4 — true for the f32/u64 element types shown in the docs,
    /// but worth confirming for exotic `T` sizes.
    fn get_neighbors(&self, node_id: u32) -> &[u32] {
        let offset = self.adjacency_offset + (node_id as u64 * self.max_degree as u64 * 4);
        let start = offset as usize;
        let end = start + (self.max_degree * 4);
        let bytes = &self.mmap[start..end];
        bytemuck::cast_slice(bytes)
    }

    /// Computes distance between `query` and vector `idx`
    /// (reads the stored vector directly out of the mmap, no copy).
    fn distance_to(&self, query: &[T], idx: usize) -> f32 {
        let elem_sz = std::mem::size_of::<T>();
        let offset = self.vectors_offset + (idx as u64 * self.dim as u64 * elem_sz as u64);
        let start = offset as usize;
        let end = start + (self.dim * elem_sz);
        let bytes = &self.mmap[start..end];
        let vector: &[T] = bytemuck::cast_slice(bytes);
        self.dist.eval(query, vector)
    }

    /// Gets a vector from the index (copies it out of the mmap).
    pub fn get_vector(&self, idx: usize) -> Vec<T> {
        let elem_sz = std::mem::size_of::<T>();
        let offset = self.vectors_offset + (idx as u64 * self.dim as u64 * elem_sz as u64);
        let start = offset as usize;
        let end = start + (self.dim * elem_sz);
        let bytes = &self.mmap[start..end];
        let vector: &[T] = bytemuck::cast_slice(bytes);
        vector.to_vec()
    }
}
703
704/// Calculates the medoid using flat contiguous storage.
705fn calculate_medoid<T, D>(vectors: &FlatVectors<T>, dist: D) -> usize
706where
707    T: bytemuck::Pod + Copy + Send + Sync,
708    D: Distance<T> + Copy + Sync,
709{
710    let n = vectors.n;
711    let k = 8.min(n);
712    let mut rng = thread_rng();
713    let pivots: Vec<usize> = (0..k).map(|_| rng.gen_range(0..n)).collect();
714
715    let (best_idx, _best_score) = (0..n)
716        .into_par_iter()
717        .map(|i| {
718            let vi = vectors.row(i);
719            let score: f32 = pivots.iter().map(|&p| dist.eval(vi, vectors.row(p))).sum();
720            (i, score)
721        })
722        .reduce(|| (0usize, f32::MAX), |a, b| if a.1 <= b.1 { a } else { b });
723
724    best_idx
725}
726
727/// Build Vamana-like graph
728fn build_vamana_graph<T, D>(
729    vectors: &FlatVectors<T>,
730    max_degree: usize,
731    build_beam_width: usize,
732    alpha: f32,
733    passes: usize,
734    extra_seeds: usize,
735    dist: D,
736    medoid_id: u32,
737) -> Vec<Vec<u32>>
738where
739    T: bytemuck::Pod + Copy + Send + Sync,
740    D: Distance<T> + Copy + Sync,
741{
742    let n = vectors.n;
743    let mut graph = vec![Vec::<u32>::new(); n];
744
745    // Random R-out directed graph bootstrap
746    {
747        let mut rng = thread_rng();
748        let target = max_degree.min(n.saturating_sub(1));
749
750        for i in 0..n {
751            let mut s = HashSet::with_capacity(target);
752            while s.len() < target {
753                let nb = rng.gen_range(0..n);
754                if nb != i {
755                    s.insert(nb as u32);
756                }
757            }
758            graph[i] = s.into_iter().collect();
759        }
760    }
761
762    let passes = passes.max(1);
763    let mut rng = thread_rng();
764
765    for pass_idx in 0..passes {
766        let pass_alpha = if passes == 1 {
767            alpha
768        } else if pass_idx == 0 {
769            1.0
770        } else {
771            alpha
772        };
773
774        let mut order: Vec<usize> = (0..n).collect();
775        order.shuffle(&mut rng);
776
777        // Important: practical incremental insertion
778        for &u in &order {
779            let mut scratch = BuildScratch::new(n, build_beam_width, max_degree, extra_seeds);
780            scratch.candidates.clear();
781
782            // Legacy DiskANN-style candidate pool starts from search results,
783            // not whole V. We also include current adjacency as a cheap prior.
784            for &nb in &graph[u] {
785                let d = dist.eval(vectors.row(u), vectors.row(nb as usize));
786                scratch.candidates.push((nb, d));
787            }
788
789            // Seeds: medoid + distinct random starts
790            scratch.seeds.clear();
791            scratch.seeds.push(medoid_id as usize);
792            while scratch.seeds.len() < 1 + extra_seeds {
793                let s = rng.gen_range(0..n);
794                if !scratch.seeds.contains(&s) {
795                    scratch.seeds.push(s);
796                }
797            }
798
799            let seeds = scratch.seeds.clone();
800            for start in seeds {
801                greedy_search_visited_collect(
802                    vectors.row(u),
803                    vectors,
804                    &graph,
805                    start,
806                    build_beam_width,
807                    dist,
808                    &mut scratch,
809                );
810
811                for i in 0..scratch.visited_ids.len() {
812                    scratch
813                        .candidates
814                        .push((scratch.visited_ids[i], scratch.visited_dists[i]));
815                }
816            }
817
818            // Deduplicate by id, keep best distance
819            scratch.candidates.sort_by(|a, b| a.0.cmp(&b.0));
820            scratch.candidates.dedup_by(|a, b| {
821                if a.0 == b.0 {
822                    if a.1 < b.1 {
823                        *b = *a;
824                    }
825                    true
826                } else {
827                    false
828                }
829            });
830
831            let pruned = prune_neighbors(
832                u,
833                &scratch.candidates,
834                vectors,
835                max_degree,
836                pass_alpha,
837                dist,
838            );
839
840            // Set u's outgoing list
841            graph[u] = pruned.clone();
842
843            // Reverse insertion with slack-triggered overflow re-pruning
844            inter_insert_with_slack(
845                &mut graph,
846                u as u32,
847                &pruned,
848                vectors,
849                max_degree,
850                pass_alpha,
851                dist,
852            );
853        }
854    }
855
856    // Final cleanup: exactly the practical thing to keep.
857    graph
858        .into_iter()
859        .enumerate()
860        .map(|(u, neigh)| {
861            if neigh.len() <= max_degree {
862                return neigh;
863            }
864            let mut ids = neigh;
865            ids.sort_unstable();
866            ids.dedup();
867
868            let pool: Vec<(u32, f32)> = ids
869                .into_iter()
870                .filter(|&id| id as usize != u)
871                .map(|id| (id, dist.eval(vectors.row(u), vectors.row(id as usize))))
872                .collect();
873
874            prune_neighbors(u, &pool, vectors, max_degree, alpha, dist)
875        })
876        .collect()
877}
878
879/// Insert reverse edge `src -> dst` in the practical DiskANN style:
880/// - if dst's degree is still below slack * R, just append
881/// - otherwise, gather dst's current neighbors + src, deduplicate, and re-prune dst
882fn inter_insert_with_slack<T, D>(
883    graph: &mut [Vec<u32>],
884    src: u32,
885    pruned_list: &[u32],
886    vectors: &FlatVectors<T>,
887    max_degree: usize,
888    alpha: f32,
889    dist: D,
890) where
891    T: bytemuck::Pod + Copy + Send + Sync,
892    D: Distance<T> + Copy,
893{
894    let slack_limit = ((GRAPH_SLACK_FACTOR * max_degree as f32).ceil() as usize).max(max_degree);
895
896    for &dst in pruned_list {
897        let dst_usize = dst as usize;
898        let src_usize = src as usize;
899
900        if dst_usize == src_usize {
901            continue;
902        }
903
904        // already present -> nothing to do
905        if graph[dst_usize].contains(&src) {
906            continue;
907        }
908
909        if graph[dst_usize].len() < slack_limit {
910            graph[dst_usize].push(src);
911            continue;
912        }
913
914        // overflow: rebuild candidate pool and prune only this destination node
915        let mut ids = graph[dst_usize].clone();
916        ids.push(src);
917        ids.sort_unstable();
918        ids.dedup();
919
920        let pool: Vec<(u32, f32)> = ids
921            .into_iter()
922            .filter(|&id| id as usize != dst_usize)
923            .map(|id| (id, dist.eval(vectors.row(dst_usize), vectors.row(id as usize))))
924            .collect();
925
926        graph[dst_usize] = prune_neighbors(dst_usize, &pool, vectors, max_degree, alpha, dist);
927    }
928}
929
/// Build-time greedy search:
/// - dense visited marks instead of HashMap/HashSet
/// - visited_ids + visited_dists instead of recomputing distances later
/// - ordered beams instead of BinaryHeap
///
/// Output is written into `scratch.visited_ids` and `scratch.visited_dists`.
/// Mirrors the termination rule of `DiskANN::search_with_dists`: expand while
/// the best frontier entry can still improve the worst entry of the working set.
fn greedy_search_visited_collect<T, D>(
    query: &[T],
    vectors: &FlatVectors<T>,
    graph: &[Vec<u32>],
    start_id: usize,
    beam_width: usize,
    dist: D,
    scratch: &mut BuildScratch,
) where
    T: bytemuck::Pod + Copy + Send + Sync,
    D: Distance<T> + Copy,
{
    // Bumps the epoch and clears per-search buffers; previous marks become stale.
    scratch.reset_search();

    let start_dist = dist.eval(query, vectors.row(start_id));
    let start = Candidate {
        dist: start_dist,
        id: start_id as u32,
    };

    scratch.frontier.insert_unbounded(start);
    scratch.work.insert_capped(start, beam_width);
    scratch.mark_with_dist(start_id, start_dist);

    while !scratch.frontier.is_empty() {
        let best = scratch.frontier.best().unwrap();
        // Terminate when even the best unexpanded candidate cannot beat
        // the worst member of a full working set.
        if scratch.work.len() >= beam_width {
            if let Some(worst) = scratch.work.worst() {
                if best.dist >= worst.dist {
                    break;
                }
            }
        }

        let cur = scratch.frontier.pop_best().unwrap();

        for &nb in &graph[cur.id as usize] {
            let nb_usize = nb as usize;
            if scratch.is_marked(nb_usize) {
                continue;
            }

            // Distance is computed once here and remembered alongside the mark.
            let d = dist.eval(query, vectors.row(nb_usize));
            scratch.mark_with_dist(nb_usize, d);

            let cand = Candidate { dist: d, id: nb };

            if scratch.work.len() < beam_width {
                scratch.work.insert_unbounded(cand);
                scratch.frontier.insert_unbounded(cand);
            } else if let Some(worst) = scratch.work.worst() {
                // Only candidates that improve the working set get expanded.
                if d < worst.dist {
                    scratch.work.insert_capped(cand, beam_width);
                    scratch.frontier.insert_unbounded(cand);
                }
            }
        }
    }
}
995
996/// α-pruning
997fn prune_neighbors<T, D>(
998    node_id: usize,
999    candidates: &[(u32, f32)],
1000    vectors: &FlatVectors<T>,
1001    max_degree: usize,
1002    alpha: f32,
1003    dist: D,
1004) -> Vec<u32>
1005where
1006    T: bytemuck::Pod + Copy + Send + Sync,
1007    D: Distance<T> + Copy,
1008{
1009    if candidates.is_empty() || max_degree == 0 {
1010        return Vec::new();
1011    }
1012
1013    // Sort by distance from node_id, nearest first.
1014    let mut sorted = candidates.to_vec();
1015    sorted.sort_by(|a, b| a.1.total_cmp(&b.1));
1016
1017    // Remove self and duplicate ids while keeping the best (nearest) occurrence.
1018    let mut uniq = Vec::<(u32, f32)>::with_capacity(sorted.len());
1019    let mut last_id: Option<u32> = None;
1020    for &(cand_id, cand_dist) in &sorted {
1021        if cand_id as usize == node_id {
1022            continue;
1023        }
1024        if last_id == Some(cand_id) {
1025            continue;
1026        }
1027        uniq.push((cand_id, cand_dist));
1028        last_id = Some(cand_id);
1029    }
1030
1031    let mut pruned = Vec::<u32>::with_capacity(max_degree);
1032
1033    // Pure robust pruning: DO NOT backfill rejected candidates.
1034    for &(cand_id, cand_dist_to_node) in &uniq {
1035        let mut occluded = false;
1036
1037        for &sel_id in &pruned {
1038            let d_cand_sel = dist.eval(
1039                vectors.row(cand_id as usize),
1040                vectors.row(sel_id as usize),
1041            );
1042
1043            // If selected neighbor sel_id "covers" cand_id, reject cand_id.
1044            if alpha * d_cand_sel <= cand_dist_to_node {
1045                occluded = true;
1046                break;
1047            }
1048        }
1049
1050        if !occluded {
1051            pruned.push(cand_id);
1052            if pruned.len() >= max_degree {
1053                break;
1054            }
1055        }
1056    }
1057
1058    pruned
1059}
1060
#[cfg(test)]
mod tests {
    use super::*;
    use anndists::dist::{DistCosine, DistL2};
    use rand::Rng;
    use std::fs;

    /// Reference Euclidean distance, computed independently of the index's
    /// `Distance` implementation, used to sanity-check returned neighbors.
    fn euclid(a: &[f32], b: &[f32]) -> f32 {
        a.iter()
            .zip(b)
            .map(|(x, y)| (x - y) * (x - y))
            .sum::<f32>()
            .sqrt()
    }

    /// Tiny 2-D corpus with default build parameters: a query near the origin
    /// must return k results whose top hit actually lies near the origin.
    #[test]
    fn test_small_index_l2() {
        let path = "test_small_l2.db";
        // Clean up leftovers from a previous (possibly failed) run.
        let _ = fs::remove_file(path);

        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
            vec![0.5, 0.5],
        ];

        let index = DiskANN::<f32, DistL2>::build_index_default(&vectors, DistL2, path).unwrap();

        let q = vec![0.1, 0.1];
        let nns = index.search(&q, 3, 8);
        assert_eq!(nns.len(), 3);

        // The best hit should be genuinely close to the query.
        let v = index.get_vector(nns[0] as usize);
        assert!(euclid(&q, &v) < 1.0);

        let _ = fs::remove_file(path);
    }

    /// Cosine metric: the top hit for a query along +x must have a high
    /// cosine similarity with it, regardless of vector magnitude.
    #[test]
    fn test_cosine() {
        let path = "test_cosine.db";
        let _ = fs::remove_file(path);

        let vectors = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
            vec![1.0, 1.0, 0.0],
            vec![1.0, 0.0, 1.0],
        ];

        let index =
            DiskANN::<f32, DistCosine>::build_index_default(&vectors, DistCosine, path).unwrap();

        // Same direction as vector 0 but scaled; cosine should ignore the scale.
        let q = vec![2.0, 0.0, 0.0];
        let nns = index.search(&q, 2, 8);
        assert_eq!(nns.len(), 2);

        // Recompute cosine similarity by hand and require a strong match.
        let v = index.get_vector(nns[0] as usize);
        let dot = v.iter().zip(&q).map(|(a, b)| a * b).sum::<f32>();
        let n1 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        let n2 = q.iter().map(|x| x * x).sum::<f32>().sqrt();
        let cos = dot / (n1 * n2);
        assert!(cos > 0.7);

        let _ = fs::remove_file(path);
    }

    /// Round-trip: build an index, drop it, re-open the file from disk, and
    /// verify both the stored metadata and that search still works.
    #[test]
    fn test_persistence_and_open() {
        let path = "test_persist.db";
        let _ = fs::remove_file(path);

        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
        ];

        // Scope the builder so it is dropped before we re-open the file.
        {
            let _idx =
                DiskANN::<f32, DistL2>::build_index_default(&vectors, DistL2, path).unwrap();
        }

        let idx2 = DiskANN::<f32, DistL2>::open_index_default_metric(path).unwrap();
        assert_eq!(idx2.num_vectors, 4);
        assert_eq!(idx2.dim, 2);

        let q = vec![0.9, 0.9];
        let res = idx2.search(&q, 2, 8);
        // (1, 1) is vector index 3 and is the unique nearest point to (0.9, 0.9).
        assert_eq!(res[0], 3);

        let _ = fs::remove_file(path);
    }

    /// 5x5 integer grid built with a small max_degree: every grid point used
    /// as a query should either find itself or land very close, and its top-5
    /// hits must all be nearby grid points.
    #[test]
    fn test_grid_connectivity() {
        let path = "test_grid.db";
        let _ = fs::remove_file(path);

        let mut vectors = Vec::new();
        for i in 0..5 {
            for j in 0..5 {
                vectors.push(vec![i as f32, j as f32]);
            }
        }

        // Deliberately tight degree (4) to stress graph connectivity.
        let index = DiskANN::<f32, DistL2>::build_index_with_params(
            &vectors,
            DistL2,
            path,
            DiskAnnParams {
                max_degree: 4,
                build_beam_width: 64,
                alpha: 1.5,
                passes: DISKANN_DEFAULT_PASSES,
                extra_seeds: DISKANN_DEFAULT_EXTRA_SEEDS,
            },
        )
        .unwrap();

        for target in 0..vectors.len() {
            let q = &vectors[target];
            let nns = index.search(q, 10, 32);
            // ANN search may miss the exact point; accept a very close hit instead.
            if !nns.contains(&(target as u32)) {
                let v = index.get_vector(nns[0] as usize);
                assert!(euclid(q, &v) < 2.0);
            }
            for &nb in nns.iter().take(5) {
                let v = index.get_vector(nb as usize);
                assert!(euclid(q, &v) < 5.0);
            }
        }

        let _ = fs::remove_file(path);
    }

    /// 200 random 32-dim vectors: search must return exactly k results, and
    /// they must come back ordered by increasing distance to the query.
    #[test]
    fn test_medium_random() {
        let path = "test_medium.db";
        let _ = fs::remove_file(path);

        let n = 200usize;
        let d = 32usize;
        let mut rng = rand::thread_rng();
        // `r#gen`: raw identifier because `gen` is a reserved keyword in the
        // Rust 2024 edition.
        let vectors: Vec<Vec<f32>> = (0..n)
            .map(|_| (0..d).map(|_| rng.r#gen::<f32>()).collect())
            .collect();

        let index = DiskANN::<f32, DistL2>::build_index_with_params(
            &vectors,
            DistL2,
            path,
            DiskAnnParams {
                max_degree: 32,
                build_beam_width: 128,
                alpha: 1.2,
                passes: DISKANN_DEFAULT_PASSES,
                extra_seeds: DISKANN_DEFAULT_EXTRA_SEEDS,
            },
        )
        .unwrap();

        let q: Vec<f32> = (0..d).map(|_| rng.r#gen::<f32>()).collect();
        let res = index.search(&q, 10, 64);
        assert_eq!(res.len(), 10);

        // Recompute true distances and require the result list to be
        // nearest-first (i.e. already sorted ascending).
        let dists: Vec<f32> = res
            .iter()
            .map(|&id| {
                let v = index.get_vector(id as usize);
                euclid(&q, &v)
            })
            .collect();
        let mut sorted = dists.clone();
        sorted.sort_by(|a, b| a.total_cmp(b));
        assert_eq!(dists, sorted);

        let _ = fs::remove_file(path);
    }
}