// rust_diskann/lib.rs
//! # DiskAnn (generic over `anndists::Distance<T>`)
//!
//! An on-disk DiskANN library that:
//! - Builds a Vamana-style graph (greedy + α-pruning) in memory
//! - Writes vectors + fixed-degree adjacency to a single file
//! - Memory-maps the file for low-overhead reads
//! - Is **generic over any Distance<T>** from `anndists` (e.g. L2 on `f32`, Cosine on `f32`,
//!   Hamming on `u64`, …)
//!
//! ## Example (f32 + L2)
//! ```no_run
//! use anndists::dist::DistL2;
//! use rust_diskann::{DiskANN, DiskAnnParams};
//!
//! let vectors: Vec<Vec<f32>> = vec![vec![0.0; 128]; 1000];
//! let index = DiskANN::<f32, DistL2>::build_index_default(&vectors, DistL2, "index.db").unwrap();
//!
//! let q = vec![0.0; 128];
//! let nns = index.search(&q, 10, 64);
//! ```
//!
//! ## Example (u64 + Hamming)
//! ```no_run
//! use anndists::dist::DistHamming;
//! use rust_diskann::{DiskANN, DiskAnnParams};
//! let vectors: Vec<Vec<u64>> = vec![vec![0u64; 128]; 1000];
//! let idx = DiskANN::<u64, DistHamming>::build_index_default(&vectors, DistHamming, "mh.db").unwrap();
//! let q = vec![0u64; 128];
//! let _ = idx.search(&q, 10, 64);
//! ```
//!
//! ## File Layout
//! ```text
//! [ metadata_len:u64 ][ metadata (bincode) ][ padding up to vectors_offset ]
//! [ vectors (num * dim * T) ][ adjacency (num * max_degree * u32) ]
//! ```
//!
//! `vectors_offset` is a fixed 1 MiB gap by default.

use anndists::prelude::Distance;
use memmap2::Mmap;
use rand::{prelude::*, thread_rng};
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use thiserror::Error;

use std::cmp::{Ordering, Reverse};
use std::collections::{BinaryHeap, HashSet};
use std::fs::OpenOptions;
use std::io::{Read, Seek, SeekFrom, Write};
use std::marker::PhantomData;
49
/// Sentinel stored in unused adjacency slots, chosen so padding can never
/// alias a real node id 0.
const PAD_U32: u32 = u32::MAX;

/// Default maximum out-degree (M) for in-memory DiskANN builds.
pub const DISKANN_DEFAULT_MAX_DEGREE: usize = 64;
/// Default construction beam width (L).
pub const DISKANN_DEFAULT_BUILD_BEAM: usize = 128;
/// Default α-pruning parameter.
pub const DISKANN_DEFAULT_ALPHA: f32 = 1.2;
/// Default number of refinement passes during graph build
pub const DISKANN_DEFAULT_PASSES: usize = 2;
/// Default number of extra random seeds per node per pass during graph build
pub const DISKANN_DEFAULT_EXTRA_SEEDS: usize = 2;

/// Optional bag of knobs if you want to override just a few.
#[derive(Clone, Copy, Debug)]
pub struct DiskAnnParams {
    pub max_degree: usize,
    pub build_beam_width: usize,
    pub alpha: f32,
    /// Number of refinement passes over the graph (>=1).
    pub passes: usize,
    /// Extra random seeds per node during each pass (>=0).
    pub extra_seeds: usize,
}

impl Default for DiskAnnParams {
    /// Mirrors the `DISKANN_DEFAULT_*` constants exactly.
    fn default() -> Self {
        DiskAnnParams {
            alpha: DISKANN_DEFAULT_ALPHA,
            passes: DISKANN_DEFAULT_PASSES,
            extra_seeds: DISKANN_DEFAULT_EXTRA_SEEDS,
            max_degree: DISKANN_DEFAULT_MAX_DEGREE,
            build_beam_width: DISKANN_DEFAULT_BUILD_BEAM,
        }
    }
}
85
/// Custom error type for DiskAnn operations.
///
/// `Io` and `Bincode` convert automatically via `#[from]`, so `?` works
/// directly on file and (de)serialization calls throughout the crate.
#[derive(Debug, Error)]
pub enum DiskAnnError {
    /// Represents I/O errors during file operations
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),

    /// Represents serialization/deserialization errors
    #[error("Serialization error: {0}")]
    Bincode(#[from] bincode::Error),

    /// Represents index-specific errors (bad input, corrupt file, type mismatch)
    #[error("Index error: {0}")]
    IndexError(String),
}
101
/// Internal metadata structure stored (bincode-encoded) at the front of the
/// index file, preceded by its byte length as a little-endian u64.
#[derive(Serialize, Deserialize, Debug)]
struct Metadata {
    /// Vector dimensionality.
    dim: usize,
    /// Number of vectors stored in the file.
    num_vectors: usize,
    /// Fixed out-degree of every adjacency row.
    max_degree: usize,
    /// Entry point used by searches.
    medoid_id: u32,
    /// Byte offset of the vector region.
    vectors_offset: u64,
    /// Byte offset of the adjacency region.
    adjacency_offset: u64,
    /// `size_of::<T>()` recorded at build time; checked on open.
    elem_size: u8,
    /// `type_name` of the distance used at build time (informational only).
    distance_name: String,
}
114
/// Candidate for search/frontier queues
#[derive(Clone, Copy, Debug)]
struct Candidate {
    dist: f32,
    id: u32,
}

impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        self.dist.to_bits() == other.dist.to_bits() && self.id == other.id
    }
}
impl Eq for Candidate {}

impl Ord for Candidate {
    /// Total order: ascending distance (IEEE 754 total order), ties broken by
    /// ascending id, so ordering is deterministic even with NaN/-0.0.
    fn cmp(&self, other: &Self) -> Ordering {
        self.dist
            .total_cmp(&other.dist)
            .then_with(|| self.id.cmp(&other.id))
    }
}
impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
141
142/// Flat contiguous matrix used during build to improve cache locality.
143///
144/// Rows are stored consecutively in `data`, row-major.
145#[derive(Clone, Debug)]
146struct FlatVectors<T> {
147    data: Vec<T>,
148    dim: usize,
149    n: usize,
150}
151
152impl<T: Copy> FlatVectors<T> {
153    fn from_vecs(vectors: &[Vec<T>]) -> Result<Self, DiskAnnError> {
154        if vectors.is_empty() {
155            return Err(DiskAnnError::IndexError("No vectors provided".to_string()));
156        }
157        let dim = vectors[0].len();
158        for (i, v) in vectors.iter().enumerate() {
159            if v.len() != dim {
160                return Err(DiskAnnError::IndexError(format!(
161                    "Vector {} has dimension {} but expected {}",
162                    i,
163                    v.len(),
164                    dim
165                )));
166            }
167        }
168
169        let n = vectors.len();
170        let mut data = Vec::with_capacity(n * dim);
171        for v in vectors {
172            data.extend_from_slice(v);
173        }
174
175        Ok(Self { data, dim, n })
176    }
177
178    #[inline]
179    fn row(&self, idx: usize) -> &[T] {
180        let start = idx * self.dim;
181        let end = start + self.dim;
182        &self.data[start..end]
183    }
184}
185
/// Small ordered beam structure used only during build-time greedy search.
///
/// It keeps elements in **descending** distance order:
/// - index 0 is the worst element
/// - last element is the best element
///
/// This makes:
/// - `best()` cheap via `last()`
/// - `worst()` cheap via `first()`
/// - capped beam maintenance simple
#[derive(Default, Debug)]
struct OrderedBeam {
    items: Vec<Candidate>,
}

impl OrderedBeam {
    /// Remove all elements, keeping the allocation for reuse.
    #[inline]
    fn clear(&mut self) {
        self.items.clear();
    }

    /// Number of elements currently held.
    #[inline]
    fn len(&self) -> usize {
        self.items.len()
    }

    #[inline]
    fn is_empty(&self) -> bool {
        self.items.is_empty()
    }

    /// Best (smallest-distance) element — stored last, so O(1).
    #[inline]
    fn best(&self) -> Option<Candidate> {
        self.items.last().copied()
    }

    /// Worst (largest-distance) element — stored first, so O(1).
    #[inline]
    fn worst(&self) -> Option<Candidate> {
        self.items.first().copied()
    }

    /// Pop and return the best element (O(1), it is the last item).
    #[inline]
    fn pop_best(&mut self) -> Option<Candidate> {
        self.items.pop()
    }

    /// Grow capacity up to `cap`; never shrinks an existing allocation.
    #[inline]
    fn reserve(&mut self, cap: usize) {
        if self.items.capacity() < cap {
            self.items.reserve(cap - self.items.capacity());
        }
    }

    /// Insert preserving the descending order; no size limit.
    ///
    /// Ties on distance (compared by raw bits) are ordered by descending id,
    /// so among equal distances the smallest id sits closest to the "best" end.
    #[inline]
    fn insert_unbounded(&mut self, cand: Candidate) {
        let pos = self.items.partition_point(|x| {
            x.dist > cand.dist || (x.dist.to_bits() == cand.dist.to_bits() && x.id > cand.id)
        });
        self.items.insert(pos, cand);
    }

    /// Insert while keeping at most `cap` elements, evicting the worst.
    ///
    /// NOTE(review): a candidate whose distance exactly ties the current worst
    /// is rejected when the beam is full (`>=` below) — appears intentional.
    #[inline]
    fn insert_capped(&mut self, cand: Candidate, cap: usize) {
        if cap == 0 {
            return;
        }

        if self.items.len() < cap {
            self.insert_unbounded(cand);
            return;
        }

        // Since items[0] is the worst, only insert if the new candidate is better.
        let worst = self.items[0];
        if cand.dist >= worst.dist {
            return;
        }

        self.insert_unbounded(cand);

        // Evict the worst element to restore the cap (O(cap) shift; beams are small).
        if self.items.len() > cap {
            self.items.remove(0);
        }
    }
}
271
/// Reusable scratch buffers for build-time greedy search.
/// One instance is created per Rayon worker via `map_init`, so allocations are reused
/// across many nodes in the build.
#[derive(Debug)]
struct BuildScratch {
    // Dense visited marks: marks[i] == epoch means node i was visited in the
    // current search. Bumping `epoch` invalidates all marks in O(1).
    marks: Vec<u32>,
    epoch: u32,

    // Nodes visited in the current search with their distances (parallel vecs).
    visited_ids: Vec<u32>,
    visited_dists: Vec<f32>,

    // Expansion frontier and the bounded working set of the greedy search.
    frontier: OrderedBeam,
    work: OrderedBeam,

    // Deduplicated search entry points (medoid + random extras).
    seeds: Vec<usize>,
    // Candidate (id, dist) pool accumulated before pruning.
    candidates: Vec<(u32, f32)>,
}

impl BuildScratch {
    /// Allocate scratch sized for a graph of `n` nodes; the capacities are
    /// heuristics meant to avoid reallocation in the common case.
    fn new(n: usize, beam_width: usize, max_degree: usize, extra_seeds: usize) -> Self {
        Self {
            marks: vec![0u32; n],
            epoch: 1,
            visited_ids: Vec::with_capacity(beam_width * 4),
            visited_dists: Vec::with_capacity(beam_width * 4),
            frontier: {
                let mut b = OrderedBeam::default();
                b.reserve(beam_width * 2);
                b
            },
            work: {
                let mut b = OrderedBeam::default();
                b.reserve(beam_width * 2);
                b
            },
            seeds: Vec::with_capacity(1 + extra_seeds),
            candidates: Vec::with_capacity(beam_width * (4 + extra_seeds) + max_degree * 2),
        }
    }

    /// Begin a fresh search: bump the epoch (implicitly clearing all marks)
    /// and empty the per-search buffers.
    #[inline]
    fn reset_search(&mut self) {
        self.epoch = self.epoch.wrapping_add(1);
        if self.epoch == 0 {
            // Epoch wrapped: stale marks could alias the new epoch, so clear for real.
            self.marks.fill(0);
            self.epoch = 1;
        }
        self.visited_ids.clear();
        self.visited_dists.clear();
        self.frontier.clear();
        self.work.clear();
    }

    /// Was `idx` already visited in the current search?
    #[inline]
    fn is_marked(&self, idx: usize) -> bool {
        self.marks[idx] == self.epoch
    }

    /// Mark `idx` visited and record its distance to the current query.
    #[inline]
    fn mark_with_dist(&mut self, idx: usize, dist: f32) {
        self.marks[idx] = self.epoch;
        self.visited_ids.push(idx as u32);
        self.visited_dists.push(dist);
    }
}
337
338/// Main struct representing a DiskANN index (generic over vector element `T` and distance `D`)
339pub struct DiskANN<T, D>
340where
341    T: bytemuck::Pod + Copy + Send + Sync + 'static,
342    D: Distance<T> + Send + Sync + Copy + Clone + 'static,
343{
344    /// Dimensionality of vectors in the index
345    pub dim: usize,
346    /// Number of vectors in the index
347    pub num_vectors: usize,
348    /// Maximum number of edges per node
349    pub max_degree: usize,
350    /// Informational: type name of the distance (from metadata)
351    pub distance_name: String,
352
353    /// ID of the medoid (used as entry point)
354    medoid_id: u32,
355    // Offsets
356    vectors_offset: u64,
357    adjacency_offset: u64,
358
359    /// Memory-mapped file
360    mmap: Mmap,
361
362    /// The distance strategy
363    dist: D,
364
365    /// keep `T` in the type so the compiler knows about it
366    _phantom: PhantomData<T>,
367}
368
369// constructors
370
371impl<T, D> DiskANN<T, D>
372where
373    T: bytemuck::Pod + Copy + Send + Sync + 'static,
374    D: Distance<T> + Send + Sync + Copy + Clone + 'static,
375{
376    /// Build with default parameters: (M=64, L=128, alpha=1.2, passes=2, extra_seeds=2).
377    pub fn build_index_default(
378        vectors: &[Vec<T>],
379        dist: D,
380        file_path: &str,
381    ) -> Result<Self, DiskAnnError> {
382        Self::build_index(
383            vectors,
384            DISKANN_DEFAULT_MAX_DEGREE,
385            DISKANN_DEFAULT_BUILD_BEAM,
386            DISKANN_DEFAULT_ALPHA,
387            DISKANN_DEFAULT_PASSES,
388            DISKANN_DEFAULT_EXTRA_SEEDS,
389            dist,
390            file_path,
391        )
392    }
393
394    /// Build with a `DiskAnnParams` bundle.
395    pub fn build_index_with_params(
396        vectors: &[Vec<T>],
397        dist: D,
398        file_path: &str,
399        p: DiskAnnParams,
400    ) -> Result<Self, DiskAnnError> {
401        Self::build_index(
402            vectors,
403            p.max_degree,
404            p.build_beam_width,
405            p.alpha,
406            p.passes,
407            p.extra_seeds,
408            dist,
409            file_path,
410        )
411    }
412
413    /// Opens an existing index file, supplying the distance strategy explicitly.
414    pub fn open_index_with(path: &str, dist: D) -> Result<Self, DiskAnnError> {
415        let mut file = OpenOptions::new().read(true).write(false).open(path)?;
416
417        // Read metadata length
418        let mut buf8 = [0u8; 8];
419        file.seek(SeekFrom::Start(0))?;
420        file.read_exact(&mut buf8)?;
421        let md_len = u64::from_le_bytes(buf8);
422
423        // Read metadata
424        let mut md_bytes = vec![0u8; md_len as usize];
425        file.read_exact(&mut md_bytes)?;
426        let metadata: Metadata = bincode::deserialize(&md_bytes)?;
427
428        let mmap = unsafe { memmap2::Mmap::map(&file)? };
429
430        // Validate element size vs T
431        let want = std::mem::size_of::<T>() as u8;
432        if metadata.elem_size != want {
433            return Err(DiskAnnError::IndexError(format!(
434                "element size mismatch: file has {}B, T is {}B",
435                metadata.elem_size, want
436            )));
437        }
438
439        // Optional sanity/logging: warn if type differs from recorded name
440        let expected = std::any::type_name::<D>();
441        if metadata.distance_name != expected {
442            eprintln!(
443                "Warning: index recorded distance `{}` but you opened with `{}`",
444                metadata.distance_name, expected
445            );
446        }
447
448        Ok(Self {
449            dim: metadata.dim,
450            num_vectors: metadata.num_vectors,
451            max_degree: metadata.max_degree,
452            distance_name: metadata.distance_name,
453            medoid_id: metadata.medoid_id,
454            vectors_offset: metadata.vectors_offset,
455            adjacency_offset: metadata.adjacency_offset,
456            mmap,
457            dist,
458            _phantom: PhantomData,
459        })
460    }
461}
462
463/// Extra sugar when your distance type implements `Default`.
464impl<T, D> DiskANN<T, D>
465where
466    T: bytemuck::Pod + Copy + Send + Sync + 'static,
467    D: Distance<T> + Default + Send + Sync + Copy + Clone + 'static,
468{
469    /// Build with default params **and** `D::default()` metric.
470    pub fn build_index_default_metric(
471        vectors: &[Vec<T>],
472        file_path: &str,
473    ) -> Result<Self, DiskAnnError> {
474        Self::build_index_default(vectors, D::default(), file_path)
475    }
476
477    /// Open an index using `D::default()` as the distance (matches what you built with).
478    pub fn open_index_default_metric(path: &str) -> Result<Self, DiskAnnError> {
479        Self::open_index_with(path, D::default())
480    }
481}
482
483impl<T, D> DiskANN<T, D>
484where
485    T: bytemuck::Pod + Copy + Send + Sync + 'static,
486    D: Distance<T> + Send + Sync + Copy + Clone + 'static,
487{
488    /// Builds a new index from provided vectors
489    ///
490    /// # Arguments
491    /// * `vectors` - The vectors to index (slice of Vec<T>)
492    /// * `max_degree` - Maximum edges per node (M ~ 24-64+)
493    /// * `build_beam_width` - Construction L (e.g., 128-400)
494    /// * `alpha` - Pruning parameter (1.2–2.0)
495    /// * `passes` - Refinement passes over the graph (>=1)
496    /// * `extra_seeds` - Extra random seeds per node per pass (>=0)
497    /// * `dist` - Any `anndists::Distance<T>`
498    /// * `file_path` - Path of index file
499    pub fn build_index(
500        vectors: &[Vec<T>],
501        max_degree: usize,
502        build_beam_width: usize,
503        alpha: f32,
504        passes: usize,
505        extra_seeds: usize,
506        dist: D,
507        file_path: &str,
508    ) -> Result<Self, DiskAnnError> {
509        let flat = FlatVectors::from_vecs(vectors)?;
510
511        let num_vectors = flat.n;
512        let dim = flat.dim;
513
514        let mut file = OpenOptions::new()
515            .create(true)
516            .write(true)
517            .read(true)
518            .truncate(true)
519            .open(file_path)?;
520
521        // Reserve space for metadata (we'll write it after data)
522        let vectors_offset = 1024 * 1024;
523        assert_eq!(
524            (vectors_offset as usize) % std::mem::align_of::<T>(),
525            0,
526            "vectors_offset must be aligned for T"
527        );
528
529        let elem_sz = std::mem::size_of::<T>() as u64;
530        let total_vector_bytes = (num_vectors as u64) * (dim as u64) * elem_sz;
531
532        // Write vectors contiguous
533        file.seek(SeekFrom::Start(vectors_offset as u64))?;
534        file.write_all(bytemuck::cast_slice::<T, u8>(&flat.data))?;
535
536        // Compute medoid using flat storage
537        let medoid_id = calculate_medoid(&flat, dist);
538
539        // Build graph
540        let adjacency_offset = vectors_offset as u64 + total_vector_bytes;
541        let graph = build_vamana_graph(
542            &flat,
543            max_degree,
544            build_beam_width,
545            alpha,
546            passes,
547            extra_seeds,
548            dist,
549            medoid_id as u32,
550        );
551
552        // Write adjacency lists
553        file.seek(SeekFrom::Start(adjacency_offset))?;
554        for neighbors in &graph {
555            let mut padded = neighbors.clone();
556            padded.resize(max_degree, PAD_U32);
557            let bytes = bytemuck::cast_slice::<u32, u8>(&padded);
558            file.write_all(bytes)?;
559        }
560
561        // Write metadata
562        let metadata = Metadata {
563            dim,
564            num_vectors,
565            max_degree,
566            medoid_id: medoid_id as u32,
567            vectors_offset: vectors_offset as u64,
568            adjacency_offset,
569            elem_size: std::mem::size_of::<T>() as u8,
570            distance_name: std::any::type_name::<D>().to_string(),
571        };
572
573        let md_bytes = bincode::serialize(&metadata)?;
574        file.seek(SeekFrom::Start(0))?;
575        let md_len = md_bytes.len() as u64;
576        file.write_all(&md_len.to_le_bytes())?;
577        file.write_all(&md_bytes)?;
578        file.sync_all()?;
579
580        // Memory map the file
581        let mmap = unsafe { memmap2::Mmap::map(&file)? };
582
583        Ok(Self {
584            dim,
585            num_vectors,
586            max_degree,
587            distance_name: metadata.distance_name,
588            medoid_id: metadata.medoid_id,
589            vectors_offset: metadata.vectors_offset,
590            adjacency_offset: metadata.adjacency_offset,
591            mmap,
592            dist,
593            _phantom: PhantomData,
594        })
595    }
596
    /// Searches the index for nearest neighbors using a best-first beam search.
    /// Termination rule: continue while the best frontier can still improve the worst in working set.
    ///
    /// Returns up to `k` `(id, distance)` pairs sorted by ascending distance.
    /// `beam_width` bounds the working set; larger values trade speed for recall.
    ///
    /// # Panics
    /// Panics if `query.len() != self.dim`.
    pub fn search_with_dists(&self, query: &[T], k: usize, beam_width: usize) -> Vec<(u32, f32)> {
        assert_eq!(
            query.len(),
            self.dim,
            "Query dim {} != index dim {}",
            query.len(),
            self.dim
        );

        let mut visited = HashSet::new();
        // `frontier` is a min-heap (via Reverse) of nodes left to expand;
        // `w` is a max-heap serving as the bounded working/result set.
        let mut frontier: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new();
        let mut w: BinaryHeap<Candidate> = BinaryHeap::new();

        // Enter the graph at the medoid recorded at build time.
        let start_dist = self.distance_to(query, self.medoid_id as usize);
        let start = Candidate {
            dist: start_dist,
            id: self.medoid_id,
        };
        frontier.push(Reverse(start));
        w.push(start);
        visited.insert(self.medoid_id);

        while let Some(Reverse(best)) = frontier.peek().copied() {
            // Stop once even the closest unexpanded node cannot beat the
            // worst member of a full working set.
            if w.len() >= beam_width {
                if let Some(worst) = w.peek() {
                    if best.dist >= worst.dist {
                        break;
                    }
                }
            }
            let Reverse(current) = frontier.pop().unwrap();

            for &nb in self.get_neighbors(current.id) {
                // Skip padding slots and nodes already seen.
                if nb == PAD_U32 {
                    continue;
                }
                if !visited.insert(nb) {
                    continue;
                }

                let d = self.distance_to(query, nb as usize);
                let cand = Candidate { dist: d, id: nb };

                if w.len() < beam_width {
                    w.push(cand);
                    frontier.push(Reverse(cand));
                } else if d < w.peek().unwrap().dist {
                    // Working set full: replace the current worst with this better candidate.
                    w.pop();
                    w.push(cand);
                    frontier.push(Reverse(cand));
                }
            }
        }

        // Working set -> ascending-distance top-k.
        let mut results: Vec<_> = w.into_vec();
        results.sort_by(|a, b| a.dist.total_cmp(&b.dist));
        results.truncate(k);
        results.into_iter().map(|c| (c.id, c.dist)).collect()
    }
658
659    /// search but only return neighbor ids
660    pub fn search(&self, query: &[T], k: usize, beam_width: usize) -> Vec<u32> {
661        self.search_with_dists(query, k, beam_width)
662            .into_iter()
663            .map(|(id, _dist)| id)
664            .collect()
665    }
666
667    /// Gets the neighbors of a node from the (fixed-degree) adjacency region
668    fn get_neighbors(&self, node_id: u32) -> &[u32] {
669        let offset = self.adjacency_offset + (node_id as u64 * self.max_degree as u64 * 4);
670        let start = offset as usize;
671        let end = start + (self.max_degree * 4);
672        let bytes = &self.mmap[start..end];
673        bytemuck::cast_slice(bytes)
674    }
675
676    /// Computes distance between `query` and vector `idx`
677    fn distance_to(&self, query: &[T], idx: usize) -> f32 {
678        let elem_sz = std::mem::size_of::<T>();
679        let offset = self.vectors_offset + (idx as u64 * self.dim as u64 * elem_sz as u64);
680        let start = offset as usize;
681        let end = start + (self.dim * elem_sz);
682        let bytes = &self.mmap[start..end];
683        let vector: &[T] = bytemuck::cast_slice(bytes);
684        self.dist.eval(query, vector)
685    }
686
687    /// Gets a vector from the index
688    pub fn get_vector(&self, idx: usize) -> Vec<T> {
689        let elem_sz = std::mem::size_of::<T>();
690        let offset = self.vectors_offset + (idx as u64 * self.dim as u64 * elem_sz as u64);
691        let start = offset as usize;
692        let end = start + (self.dim * elem_sz);
693        let bytes = &self.mmap[start..end];
694        let vector: &[T] = bytemuck::cast_slice(bytes);
695        vector.to_vec()
696    }
697}
698
699/// Calculates the medoid using flat contiguous storage.
700fn calculate_medoid<T, D>(vectors: &FlatVectors<T>, dist: D) -> usize
701where
702    T: bytemuck::Pod + Copy + Send + Sync,
703    D: Distance<T> + Copy + Sync,
704{
705    let n = vectors.n;
706    let k = 8.min(n);
707    let mut rng = thread_rng();
708    let pivots: Vec<usize> = (0..k).map(|_| rng.gen_range(0..n)).collect();
709
710    let (best_idx, _best_score) = (0..n)
711        .into_par_iter()
712        .map(|i| {
713            let vi = vectors.row(i);
714            let score: f32 = pivots.iter().map(|&p| dist.eval(vi, vectors.row(p))).sum();
715            (i, score)
716        })
717        .reduce(|| (0usize, f32::MAX), |a, b| if a.1 <= b.1 { a } else { b });
718
719    best_idx
720}
721
/// Build Vamana-like graph using:
/// - flat contiguous vectors
/// - per-worker reusable scratch
/// - dense visited marks
/// - ordered beams instead of BinaryHeap for build-time greedy search
/// - batched parallel symmetrization-and-repruning
fn build_vamana_graph<T, D>(
    vectors: &FlatVectors<T>,
    max_degree: usize,
    build_beam_width: usize,
    alpha: f32,
    passes: usize,
    extra_seeds: usize,
    dist: D,
    medoid_id: u32,
) -> Vec<Vec<u32>>
where
    T: bytemuck::Pod + Copy + Send + Sync,
    D: Distance<T> + Copy + Sync,
{
    let n = vectors.n;
    let mut graph = vec![Vec::<u32>::new(); n];

    // Random R-out directed graph bootstrap
    {
        let mut rng = thread_rng();
        let target = max_degree.min(n.saturating_sub(1));

        for i in 0..n {
            // Rejection-sample `target` distinct neighbors != i.
            let mut s = HashSet::with_capacity(target);
            while s.len() < target {
                let nb = rng.gen_range(0..n);
                if nb != i {
                    s.insert(nb as u32);
                }
            }
            graph[i] = s.into_iter().collect();
        }
    }

    let passes = passes.max(1);
    let mut rng = thread_rng();

    for pass_idx in 0..passes {
        // With multiple passes, the first pass prunes at alpha = 1.0 and the
        // later passes use the caller's alpha (two-phase construction).
        let pass_alpha = if passes == 1 {
            alpha
        } else if pass_idx == 0 {
            1.0
        } else {
            alpha
        };

        // Visit nodes in a fresh random order each pass.
        let mut order: Vec<usize> = (0..n).collect();
        order.shuffle(&mut rng);

        // Read-only snapshot of the previous pass's adjacency.
        let snapshot = &graph;

        // new_graph[pos] holds the pruned out-list of node order[pos].
        let new_graph: Vec<Vec<u32>> = order
            .par_iter()
            .map_init(
                || BuildScratch::new(n, build_beam_width, max_degree, extra_seeds),
                |scratch, &u| {
                    scratch.candidates.clear();

                    // Include current adjacency with distances
                    for &nb in &snapshot[u] {
                        let d = dist.eval(vectors.row(u), vectors.row(nb as usize));
                        scratch.candidates.push((nb, d));
                    }

                    // Deduplicated seeds: medoid + distinct random starts
                    scratch.seeds.clear();
                    scratch.seeds.push(medoid_id as usize);
                    let mut trng = thread_rng();
                    while scratch.seeds.len() < 1 + extra_seeds {
                        let s = trng.gen_range(0..n);
                        if !scratch.seeds.contains(&s) {
                            scratch.seeds.push(s);
                        }
                    }

                    // Gather candidates from greedy search visited sets.
                    // The clone is cheap (at most 1 + extra_seeds usizes) and
                    // frees `scratch` to be borrowed mutably below.
                    let seeds = scratch.seeds.clone();
                    for start in seeds {
                        greedy_search_visited_collect(
                            vectors.row(u),
                            vectors,
                            snapshot,
                            start,
                            build_beam_width,
                            dist,
                            scratch,
                        );

                        for i in 0..scratch.visited_ids.len() {
                            scratch
                                .candidates
                                .push((scratch.visited_ids[i], scratch.visited_dists[i]));
                        }
                    }

                    // Deduplicate by id keeping best distance
                    scratch.candidates.sort_by(|a, b| a.0.cmp(&b.0));
                    scratch.candidates.dedup_by(|a, b| {
                        if a.0 == b.0 {
                            // dedup_by drops `a` and keeps `b` (the earlier
                            // element), so copy the smaller distance into `b`.
                            if a.1 < b.1 {
                                *b = *a;
                            }
                            true
                        } else {
                            false
                        }
                    });

                    prune_neighbors(
                        u,
                        &scratch.candidates,
                        vectors,
                        max_degree,
                        pass_alpha,
                        dist,
                    )
                },
            )
            .collect();

        // Batched parallel symmetrization-and-repruning
        // pos_of maps node id -> its position in `order` (and thus new_graph).
        let mut pos_of = vec![0usize; n];
        for (pos, &u) in order.iter().enumerate() {
            pos_of[u] = pos;
        }

        // CSR view of the incoming edges implied by new_graph.
        // NOTE(review): build_incoming_csr is defined elsewhere in this file
        // (not visible in this chunk); presumably returns (flat ids, offsets).
        let (incoming_flat, incoming_off) = build_incoming_csr(&order, &new_graph, n);

        graph = (0..n)
            .into_par_iter()
            .map(|u| {
                let ng = &new_graph[pos_of[u]];
                let inc = &incoming_flat[incoming_off[u]..incoming_off[u + 1]];

                // Union of out-neighbors and in-neighbors, deduplicated.
                let mut pool_ids: Vec<u32> = Vec::with_capacity(ng.len() + inc.len());
                pool_ids.extend_from_slice(ng);
                pool_ids.extend_from_slice(inc);
                pool_ids.sort_unstable();
                pool_ids.dedup();

                let pool: Vec<(u32, f32)> = pool_ids
                    .into_iter()
                    .filter(|&id| id as usize != u)
                    .map(|id| (id, dist.eval(vectors.row(u), vectors.row(id as usize))))
                    .collect();

                prune_neighbors(u, &pool, vectors, max_degree, pass_alpha, dist)
            })
            .collect();
    }

    // Final cleanup: re-prune any node whose list still exceeds max_degree.
    graph
        .into_par_iter()
        .enumerate()
        .map(|(u, neigh)| {
            if neigh.len() <= max_degree {
                return neigh;
            }
            let pool: Vec<(u32, f32)> = neigh
                .iter()
                .map(|&id| (id, dist.eval(vectors.row(u), vectors.row(id as usize))))
                .collect();
            prune_neighbors(u, &pool, vectors, max_degree, alpha, dist)
        })
        .collect()
}
895
/// Build-time greedy search:
/// - dense visited marks instead of HashMap/HashSet
/// - visited_ids + visited_dists instead of recomputing distances later
/// - ordered beams instead of BinaryHeap
///
/// Output is written into `scratch.visited_ids` and `scratch.visited_dists`.
fn greedy_search_visited_collect<T, D>(
    query: &[T],
    vectors: &FlatVectors<T>,
    graph: &[Vec<u32>],
    start_id: usize,
    beam_width: usize,
    dist: D,
    scratch: &mut BuildScratch,
) where
    T: bytemuck::Pod + Copy + Send + Sync,
    D: Distance<T> + Copy,
{
    // Bump the epoch and clear all per-search buffers.
    scratch.reset_search();

    let start_dist = dist.eval(query, vectors.row(start_id));
    let start = Candidate {
        dist: start_dist,
        id: start_id as u32,
    };

    scratch.frontier.insert_unbounded(start);
    scratch.work.insert_capped(start, beam_width);
    scratch.mark_with_dist(start_id, start_dist);

    while !scratch.frontier.is_empty() {
        let best = scratch.frontier.best().unwrap();
        // Same termination rule as query-time search: stop once the closest
        // unexpanded node cannot improve a full working set.
        if scratch.work.len() >= beam_width {
            if let Some(worst) = scratch.work.worst() {
                if best.dist >= worst.dist {
                    break;
                }
            }
        }

        let cur = scratch.frontier.pop_best().unwrap();

        for &nb in &graph[cur.id as usize] {
            let nb_usize = nb as usize;
            if scratch.is_marked(nb_usize) {
                continue;
            }

            // Each node's distance is computed exactly once and recorded for
            // later candidate collection.
            let d = dist.eval(query, vectors.row(nb_usize));
            scratch.mark_with_dist(nb_usize, d);

            let cand = Candidate { dist: d, id: nb };

            if scratch.work.len() < beam_width {
                scratch.work.insert_unbounded(cand);
                scratch.frontier.insert_unbounded(cand);
            } else if let Some(worst) = scratch.work.worst() {
                if d < worst.dist {
                    scratch.work.insert_capped(cand, beam_width);
                    scratch.frontier.insert_unbounded(cand);
                }
            }
        }
    }
}
961
962/// α-pruning
963fn prune_neighbors<T, D>(
964    node_id: usize,
965    candidates: &[(u32, f32)],
966    vectors: &FlatVectors<T>,
967    max_degree: usize,
968    alpha: f32,
969    dist: D,
970) -> Vec<u32>
971where
972    T: bytemuck::Pod + Copy + Send + Sync,
973    D: Distance<T> + Copy,
974{
975    if candidates.is_empty() {
976        return Vec::new();
977    }
978
979    let mut sorted = candidates.to_vec();
980    sorted.sort_by(|a, b| a.1.total_cmp(&b.1));
981
982    let mut pruned = Vec::<u32>::new();
983
984    for &(cand_id, cand_dist) in &sorted {
985        if cand_id as usize == node_id {
986            continue;
987        }
988        let mut ok = true;
989        for &sel in &pruned {
990            let d = dist.eval(vectors.row(cand_id as usize), vectors.row(sel as usize));
991            if alpha * d <= cand_dist {
992                ok = false;
993                break;
994            }
995        }
996        if ok {
997            pruned.push(cand_id);
998            if pruned.len() >= max_degree {
999                break;
1000            }
1001        }
1002    }
1003
1004    for &(cand_id, _) in &sorted {
1005        if cand_id as usize == node_id {
1006            continue;
1007        }
1008        if !pruned.contains(&cand_id) {
1009            pruned.push(cand_id);
1010            if pruned.len() >= max_degree {
1011                break;
1012            }
1013        }
1014    }
1015
1016    pruned
1017}
1018
/// Builds a CSR (offsets + flat array) view of the *incoming* edges implied
/// by `new_graph`, where `new_graph[pos]` holds the out-neighbors of node
/// `order[pos]`.
///
/// Returns `(incoming_flat, off)` such that the in-neighbors of node `v`
/// are `incoming_flat[off[v]..off[v + 1]]`.
fn build_incoming_csr(order: &[usize], new_graph: &[Vec<u32>], n: usize) -> (Vec<u32>, Vec<usize>) {
    // Pass 1: in-degree of every node.
    let mut indeg = vec![0usize; n];
    for pos in 0..order.len() {
        for &dst in &new_graph[pos] {
            indeg[dst as usize] += 1;
        }
    }

    // Exclusive prefix sum over the in-degrees -> CSR offsets.
    let mut off = vec![0usize; n + 1];
    for (i, &deg) in indeg.iter().enumerate() {
        off[i + 1] = off[i] + deg;
    }

    // Pass 2: scatter each source id into its destination's slot range,
    // advancing a per-node write cursor as slots are consumed.
    let mut next_slot = off.clone();
    let mut incoming_flat = vec![0u32; off[n]];
    for (pos, &src) in order.iter().enumerate() {
        for &dst in &new_graph[pos] {
            incoming_flat[next_slot[dst as usize]] = src as u32;
            next_slot[dst as usize] += 1;
        }
    }

    (incoming_flat, off)
}
1043
#[cfg(test)]
mod tests {
    use super::*;
    use anndists::dist::{DistCosine, DistL2};
    use rand::Rng;
    use std::fs;

    /// Plain Euclidean distance, computed independently of the index's
    /// `Distance` impl, used as ground truth in the assertions below.
    fn euclid(a: &[f32], b: &[f32]) -> f32 {
        a.iter()
            .zip(b)
            .map(|(x, y)| (x - y) * (x - y))
            .sum::<f32>()
            .sqrt()
    }

    /// Smoke test: build a tiny 2-D L2 index and check that a near-origin
    /// query returns the requested number of neighbors, with the top hit
    /// actually close to the query.
    #[test]
    fn test_small_index_l2() {
        let path = "test_small_l2.db";
        // Remove any leftover file from a previous (failed) run.
        let _ = fs::remove_file(path);

        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
            vec![0.5, 0.5],
        ];

        let index = DiskANN::<f32, DistL2>::build_index_default(&vectors, DistL2, path).unwrap();

        let q = vec![0.1, 0.1];
        let nns = index.search(&q, 3, 8);
        assert_eq!(nns.len(), 3);

        // Top result must be within distance 1.0 of the query (the true
        // nearest, [0,0], is at ~0.141).
        let v = index.get_vector(nns[0] as usize);
        assert!(euclid(&q, &v) < 1.0);

        let _ = fs::remove_file(path);
    }

    /// Verifies the index is generic over the metric: with DistCosine the
    /// top hit for a query along +x should have high cosine similarity.
    #[test]
    fn test_cosine() {
        let path = "test_cosine.db";
        let _ = fs::remove_file(path);

        let vectors = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
            vec![1.0, 1.0, 0.0],
            vec![1.0, 0.0, 1.0],
        ];

        let index =
            DiskANN::<f32, DistCosine>::build_index_default(&vectors, DistCosine, path).unwrap();

        // Same direction as [1,0,0] but scaled — cosine should ignore norm.
        let q = vec![2.0, 0.0, 0.0];
        let nns = index.search(&q, 2, 8);
        assert_eq!(nns.len(), 2);

        // Recompute cosine similarity by hand against the returned vector.
        let v = index.get_vector(nns[0] as usize);
        let dot = v.iter().zip(&q).map(|(a, b)| a * b).sum::<f32>();
        let n1 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        let n2 = q.iter().map(|x| x * x).sum::<f32>().sqrt();
        let cos = dot / (n1 * n2);
        assert!(cos > 0.7);

        let _ = fs::remove_file(path);
    }

    /// Build, drop, reopen: metadata (count/dim) and search results must
    /// survive a round-trip through the on-disk file.
    #[test]
    fn test_persistence_and_open() {
        let path = "test_persist.db";
        let _ = fs::remove_file(path);

        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
        ];

        // Inner scope drops the built index (and its file handle) before
        // reopening. NOTE(review): assumes build flushes fully on completion.
        {
            let _idx =
                DiskANN::<f32, DistL2>::build_index_default(&vectors, DistL2, path).unwrap();
        }

        let idx2 = DiskANN::<f32, DistL2>::open_index_default_metric(path).unwrap();
        assert_eq!(idx2.num_vectors, 4);
        assert_eq!(idx2.dim, 2);

        // [1,1] (id 3) is the unambiguous nearest to [0.9, 0.9].
        let q = vec![0.9, 0.9];
        let res = idx2.search(&q, 2, 8);
        assert_eq!(res[0], 3);

        let _ = fs::remove_file(path);
    }

    /// Connectivity check on a 5x5 integer grid with a deliberately small
    /// max_degree: every point used as a query should find itself or at
    /// least something nearby, and the top-5 hits should all be local.
    #[test]
    fn test_grid_connectivity() {
        let path = "test_grid.db";
        let _ = fs::remove_file(path);

        let mut vectors = Vec::new();
        for i in 0..5 {
            for j in 0..5 {
                vectors.push(vec![i as f32, j as f32]);
            }
        }

        // Small degree (4) stresses pruning/connectivity more than defaults.
        let index = DiskANN::<f32, DistL2>::build_index_with_params(
            &vectors,
            DistL2,
            path,
            DiskAnnParams {
                max_degree: 4,
                build_beam_width: 64,
                alpha: 1.5,
                passes: DISKANN_DEFAULT_PASSES,
                extra_seeds: DISKANN_DEFAULT_EXTRA_SEEDS,
            },
        )
        .unwrap();

        for target in 0..vectors.len() {
            let q = &vectors[target];
            let nns = index.search(q, 10, 32);
            // ANN search may miss the exact point; then the best hit must
            // still be a close grid neighbor.
            if !nns.contains(&(target as u32)) {
                let v = index.get_vector(nns[0] as usize);
                assert!(euclid(q, &v) < 2.0);
            }
            // The top 5 results should all be within a local neighborhood.
            for &nb in nns.iter().take(5) {
                let v = index.get_vector(nb as usize);
                assert!(euclid(q, &v) < 5.0);
            }
        }

        let _ = fs::remove_file(path);
    }

    /// Random 200x32 dataset: checks result count and, crucially, that
    /// search returns neighbors sorted by increasing true L2 distance.
    #[test]
    fn test_medium_random() {
        let path = "test_medium.db";
        let _ = fs::remove_file(path);

        let n = 200usize;
        let d = 32usize;
        let mut rng = rand::thread_rng();
        let vectors: Vec<Vec<f32>> = (0..n)
            .map(|_| (0..d).map(|_| rng.r#gen::<f32>()).collect())
            .collect();

        let index = DiskANN::<f32, DistL2>::build_index_with_params(
            &vectors,
            DistL2,
            path,
            DiskAnnParams {
                max_degree: 32,
                build_beam_width: 128,
                alpha: 1.2,
                passes: DISKANN_DEFAULT_PASSES,
                extra_seeds: DISKANN_DEFAULT_EXTRA_SEEDS,
            },
        )
        .unwrap();

        let q: Vec<f32> = (0..d).map(|_| rng.r#gen::<f32>()).collect();
        let res = index.search(&q, 10, 64);
        assert_eq!(res.len(), 10);

        // Recompute true distances for the returned ids and verify the
        // result list is already in ascending-distance order.
        let dists: Vec<f32> = res
            .iter()
            .map(|&id| {
                let v = index.get_vector(id as usize);
                euclid(&q, &v)
            })
            .collect();
        let mut sorted = dists.clone();
        sorted.sort_by(|a, b| a.total_cmp(b));
        assert_eq!(dists, sorted);

        let _ = fs::remove_file(path);
    }
}