Skip to main content

rust_diskann/
lib.rs

1//! # DiskAnn (generic over `anndists::Distance<T>`)
2//!
3//! An on-disk DiskANN library that:
4//! - Builds a Vamana-style graph (greedy + α-pruning) in memory
5//! - Writes vectors + fixed-degree adjacency to a single file
6//! - Memory-maps the file for low-overhead reads
7//! - Is **generic over any Distance<T>** from `anndists` (e.g. L2 on `f32`, Cosine on `f32`,
8//!   Hamming on `u64`, …)
9//!
10//! ## Example (f32 + L2)
11//! ```no_run
12//! use anndists::dist::DistL2;
13//! use rust_diskann::{DiskANN, DiskAnnParams};
14//!
15//! let vectors: Vec<Vec<f32>> = vec![vec![0.0; 128]; 1000];
16//! let index = DiskANN::<f32, DistL2>::build_index_default(&vectors, DistL2, "index.db").unwrap();
17//!
18//! let q = vec![0.0; 128];
19//! let nns = index.search(&q, 10, 64);
20//! ```
21//!
//! ## Example (u64 + Hamming)
//! ```no_run
//! use anndists::dist::DistHamming;
//! use rust_diskann::DiskANN;
//!
//! let vectors: Vec<Vec<u64>> = vec![vec![0u64; 128]; 1000];
//! let index = DiskANN::<u64, DistHamming>::build_index_default(&vectors, DistHamming, "mh.db").unwrap();
//! let q = vec![0u64; 128];
//! let _nns = index.search(&q, 10, 64);
//! ```
//!
//! ## File Layout
//! ```text
//! [ metadata_len:u64 ][ metadata (bincode) ][ padding up to vectors_offset ]
//! [ vectors (num * dim * T) ][ adjacency (num * max_degree * u32) ]
//! ```
//!
//! `vectors_offset` is a fixed 1 MiB gap reserved for the length prefix and metadata header.
37
38use anndists::prelude::Distance;
39use memmap2::Mmap;
40use rand::{prelude::*, thread_rng};
41use rayon::prelude::*;
42use serde::{Deserialize, Serialize};
43use std::cmp::{Ordering, Reverse};
44use std::collections::{BinaryHeap, HashSet};
45use std::fs::OpenOptions;
46use std::io::{Read, Seek, SeekFrom, Write};
47use std::marker::PhantomData;
48use thiserror::Error;
49
/// Padding sentinel for unused adjacency slots.
///
/// `u32::MAX` is used rather than 0 so that padding can never be confused with node 0.
const PAD_U32: u32 = u32::MAX;

/// Default maximum out-degree `R` of every graph node (and the number of `u32`
/// slots in each on-disk adjacency row).
pub const DISKANN_DEFAULT_MAX_DEGREE: usize = 64;
/// Default construction beam width `L` used by the build-time greedy search.
pub const DISKANN_DEFAULT_BUILD_BEAM: usize = 128;
/// Default pruning parameter α used when selecting neighbors during the build.
pub const DISKANN_DEFAULT_ALPHA: f32 = 1.2;
/// Default number of refinement passes during graph build.
pub const DISKANN_DEFAULT_PASSES: usize = 1;
/// Default number of extra random seeds per node per pass during graph build.
pub const DISKANN_DEFAULT_EXTRA_SEEDS: usize = 1;

/// Practical DiskANN-style slack before reverse-neighbor re-pruning.
///
/// Legacy C++ DiskANN lets reverse lists grow to about `GRAPH_SLACK_FACTOR * R`
/// before triggering a prune, instead of pruning immediately at `R`.
const GRAPH_SLACK_FACTOR: f32 = 1.3;

/// Number of nodes processed together in one micro-batch during graph build.
///
/// Smaller values:
/// - are closer to true incremental insertion,
/// - usually improve faithfulness to practical Vamana behavior,
/// - but are slower.
///
/// Larger values:
/// - are faster,
/// - but behave more like a batched graph rebuild.
///
/// Recommended starting values:
/// - 128 for better quality,
/// - 256 for a balanced tradeoff,
/// - 512 for faster builds.
const MICRO_BATCH_CHUNK_SIZE: usize = 256;
84
85/// Optional bag of knobs if you want to override just a few.
86#[derive(Clone, Copy, Debug)]
87pub struct DiskAnnParams {
88    pub max_degree: usize,
89    pub build_beam_width: usize,
90    pub alpha: f32,
91    /// Number of refinement passes over the graph (>=1).
92    pub passes: usize,
93    /// Extra random seeds per node during each pass (>=0).
94    pub extra_seeds: usize,
95}
96
97impl Default for DiskAnnParams {
98    fn default() -> Self {
99        Self {
100            max_degree: DISKANN_DEFAULT_MAX_DEGREE,
101            build_beam_width: DISKANN_DEFAULT_BUILD_BEAM,
102            alpha: DISKANN_DEFAULT_ALPHA,
103            passes: DISKANN_DEFAULT_PASSES,
104            extra_seeds: DISKANN_DEFAULT_EXTRA_SEEDS,
105        }
106    }
107}
108
109/// Custom error type for DiskAnn operations
110#[derive(Debug, Error)]
111pub enum DiskAnnError {
112    /// Represents I/O errors during file operations
113    #[error("I/O error: {0}")]
114    Io(#[from] std::io::Error),
115
116    /// Represents serialization/deserialization errors
117    #[error("Serialization error: {0}")]
118    Bincode(#[from] bincode::Error),
119
120    /// Represents index-specific errors
121    #[error("Index error: {0}")]
122    IndexError(String),
123}
124
125/// Internal metadata structure stored in the index file
126#[derive(Serialize, Deserialize, Debug)]
127struct Metadata {
128    dim: usize,
129    num_vectors: usize,
130    max_degree: usize,
131    medoid_id: u32,
132    vectors_offset: u64,
133    adjacency_offset: u64,
134    elem_size: u8,
135    distance_name: String,
136}
137
/// Search/frontier queue entry: a node id paired with its distance to the query.
///
/// Candidates are ordered by `dist` using `f32::total_cmp` (so NaN and signed zeros
/// have a fixed position), then by `id`. Equality compares `dist` bitwise, which is
/// exactly when `total_cmp` reports `Equal`, so `Eq` and `Ord` agree.
#[derive(Clone, Copy, Debug)]
struct Candidate {
    dist: f32,
    id: u32,
}

impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        self.id == other.id && self.dist.to_bits() == other.dist.to_bits()
    }
}

impl Eq for Candidate {}

impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> Ordering {
        self.dist
            .total_cmp(&other.dist)
            .then_with(|| self.id.cmp(&other.id))
    }
}

impl PartialOrd for Candidate {
    // Canonical form: the total order lives in `Ord::cmp`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
164
165/// Flat contiguous matrix used during build to improve cache locality.
166///
167/// Rows are stored consecutively in `data`, row-major.
168#[derive(Clone, Debug)]
169struct FlatVectors<T> {
170    data: Vec<T>,
171    dim: usize,
172    n: usize,
173}
174
175impl<T: Copy> FlatVectors<T> {
176    fn from_vecs(vectors: &[Vec<T>]) -> Result<Self, DiskAnnError> {
177        if vectors.is_empty() {
178            return Err(DiskAnnError::IndexError("No vectors provided".to_string()));
179        }
180        let dim = vectors[0].len();
181        for (i, v) in vectors.iter().enumerate() {
182            if v.len() != dim {
183                return Err(DiskAnnError::IndexError(format!(
184                    "Vector {} has dimension {} but expected {}",
185                    i,
186                    v.len(),
187                    dim
188                )));
189            }
190        }
191
192        let n = vectors.len();
193        let mut data = Vec::with_capacity(n * dim);
194        for v in vectors {
195            data.extend_from_slice(v);
196        }
197
198        Ok(Self { data, dim, n })
199    }
200
201    #[inline]
202    fn row(&self, idx: usize) -> &[T] {
203        let start = idx * self.dim;
204        let end = start + self.dim;
205        &self.data[start..end]
206    }
207}
208
209/// Small ordered beam structure used only during build-time greedy search.
210///
211/// It keeps elements in **descending** distance order:
212/// - index 0 is the worst element
213/// - last element is the best element
214///
215/// This makes:
216/// - `best()` cheap via `last()`
217/// - `worst()` cheap via `first()`
218/// - capped beam maintenance simple
219#[derive(Default, Debug)]
220struct OrderedBeam {
221    items: Vec<Candidate>,
222}
223
224impl OrderedBeam {
225    #[inline]
226    fn clear(&mut self) {
227        self.items.clear();
228    }
229
230    #[inline]
231    fn len(&self) -> usize {
232        self.items.len()
233    }
234
235    #[inline]
236    fn is_empty(&self) -> bool {
237        self.items.is_empty()
238    }
239
240    #[inline]
241    fn best(&self) -> Option<Candidate> {
242        self.items.last().copied()
243    }
244
245    #[inline]
246    fn worst(&self) -> Option<Candidate> {
247        self.items.first().copied()
248    }
249
250    #[inline]
251    fn pop_best(&mut self) -> Option<Candidate> {
252        self.items.pop()
253    }
254
255    #[inline]
256    fn reserve(&mut self, cap: usize) {
257        if self.items.capacity() < cap {
258            self.items.reserve(cap - self.items.capacity());
259        }
260    }
261
262    #[inline]
263    fn insert_unbounded(&mut self, cand: Candidate) {
264        let pos = self.items.partition_point(|x| {
265            x.dist > cand.dist || (x.dist.to_bits() == cand.dist.to_bits() && x.id > cand.id)
266        });
267        self.items.insert(pos, cand);
268    }
269
270    #[inline]
271    fn insert_capped(&mut self, cand: Candidate, cap: usize) {
272        if cap == 0 {
273            return;
274        }
275
276        if self.items.len() < cap {
277            self.insert_unbounded(cand);
278            return;
279        }
280
281        // Since items[0] is the worst, only insert if the new candidate is better.
282        let worst = self.items[0];
283        if cand.dist >= worst.dist {
284            return;
285        }
286
287        self.insert_unbounded(cand);
288
289        if self.items.len() > cap {
290            self.items.remove(0);
291        }
292    }
293}
294
295/// Reusable scratch buffers for build-time greedy search.
296/// One instance is created per Rayon worker via `map_init`, so allocations are reused
297/// across many nodes in the build.
298#[derive(Debug)]
299struct BuildScratch {
300    marks: Vec<u32>,
301    epoch: u32,
302
303    visited_ids: Vec<u32>,
304    visited_dists: Vec<f32>,
305
306    frontier: OrderedBeam,
307    work: OrderedBeam,
308
309    seeds: Vec<usize>,
310    candidates: Vec<(u32, f32)>,
311}
312
313impl BuildScratch {
314    fn new(n: usize, beam_width: usize, max_degree: usize, extra_seeds: usize) -> Self {
315        Self {
316            marks: vec![0u32; n],
317            epoch: 1,
318            visited_ids: Vec::with_capacity(beam_width * 4),
319            visited_dists: Vec::with_capacity(beam_width * 4),
320            frontier: {
321                let mut b = OrderedBeam::default();
322                b.reserve(beam_width * 2);
323                b
324            },
325            work: {
326                let mut b = OrderedBeam::default();
327                b.reserve(beam_width * 2);
328                b
329            },
330            seeds: Vec::with_capacity(1 + extra_seeds),
331            candidates: Vec::with_capacity(beam_width * (4 + extra_seeds) + max_degree * 2),
332        }
333    }
334
335    #[inline]
336    fn reset_search(&mut self) {
337        self.epoch = self.epoch.wrapping_add(1);
338        if self.epoch == 0 {
339            self.marks.fill(0);
340            self.epoch = 1;
341        }
342        self.visited_ids.clear();
343        self.visited_dists.clear();
344        self.frontier.clear();
345        self.work.clear();
346    }
347
348    #[inline]
349    fn is_marked(&self, idx: usize) -> bool {
350        self.marks[idx] == self.epoch
351    }
352
353    #[inline]
354    fn mark_with_dist(&mut self, idx: usize, dist: f32) {
355        self.marks[idx] = self.epoch;
356        self.visited_ids.push(idx as u32);
357        self.visited_dists.push(dist);
358    }
359}
360
361#[derive(Debug)]
362struct IncrementalInsertScratch {
363    build: BuildScratch,
364}
365
366impl IncrementalInsertScratch {
367    fn new(n: usize, beam_width: usize, max_degree: usize, extra_seeds: usize) -> Self {
368        Self {
369            build: BuildScratch::new(n, beam_width, max_degree, extra_seeds),
370        }
371    }
372}
373
374/// Main struct representing a DiskANN index (generic over vector element `T` and distance `D`)
375pub struct DiskANN<T, D>
376where
377    T: bytemuck::Pod + Copy + Send + Sync + 'static,
378    D: Distance<T> + Send + Sync + Copy + Clone + 'static,
379{
380    /// Dimensionality of vectors in the index
381    pub dim: usize,
382    /// Number of vectors in the index
383    pub num_vectors: usize,
384    /// Maximum number of edges per node
385    pub max_degree: usize,
386    /// Informational: type name of the distance (from metadata)
387    pub distance_name: String,
388
389    /// ID of the medoid (used as entry point)
390    medoid_id: u32,
391    // Offsets
392    vectors_offset: u64,
393    adjacency_offset: u64,
394
395    /// Memory-mapped file
396    mmap: Mmap,
397
398    /// The distance strategy
399    dist: D,
400
401    /// keep `T` in the type so the compiler knows about it
402    _phantom: PhantomData<T>,
403}
404
405// constructors
406
407impl<T, D> DiskANN<T, D>
408where
409    T: bytemuck::Pod + Copy + Send + Sync + 'static,
410    D: Distance<T> + Send + Sync + Copy + Clone + 'static,
411{
412    /// Build with default parameters: (M=64, L=128, alpha=1.2, passes=2, extra_seeds=2).
413    pub fn build_index_default(
414        vectors: &[Vec<T>],
415        dist: D,
416        file_path: &str,
417    ) -> Result<Self, DiskAnnError> {
418        Self::build_index(
419            vectors,
420            DISKANN_DEFAULT_MAX_DEGREE,
421            DISKANN_DEFAULT_BUILD_BEAM,
422            DISKANN_DEFAULT_ALPHA,
423            DISKANN_DEFAULT_PASSES,
424            DISKANN_DEFAULT_EXTRA_SEEDS,
425            dist,
426            file_path,
427        )
428    }
429
430    /// Build with a `DiskAnnParams` bundle.
431    pub fn build_index_with_params(
432        vectors: &[Vec<T>],
433        dist: D,
434        file_path: &str,
435        p: DiskAnnParams,
436    ) -> Result<Self, DiskAnnError> {
437        Self::build_index(
438            vectors,
439            p.max_degree,
440            p.build_beam_width,
441            p.alpha,
442            p.passes,
443            p.extra_seeds,
444            dist,
445            file_path,
446        )
447    }
448
449    /// Opens an existing index file, supplying the distance strategy explicitly.
450    pub fn open_index_with(path: &str, dist: D) -> Result<Self, DiskAnnError> {
451        let mut file = OpenOptions::new().read(true).write(false).open(path)?;
452
453        // Read metadata length
454        let mut buf8 = [0u8; 8];
455        file.seek(SeekFrom::Start(0))?;
456        file.read_exact(&mut buf8)?;
457        let md_len = u64::from_le_bytes(buf8);
458
459        // Read metadata
460        let mut md_bytes = vec![0u8; md_len as usize];
461        file.read_exact(&mut md_bytes)?;
462        let metadata: Metadata = bincode::deserialize(&md_bytes)?;
463
464        let mmap = unsafe { memmap2::Mmap::map(&file)? };
465
466        // Validate element size vs T
467        let want = std::mem::size_of::<T>() as u8;
468        if metadata.elem_size != want {
469            return Err(DiskAnnError::IndexError(format!(
470                "element size mismatch: file has {}B, T is {}B",
471                metadata.elem_size, want
472            )));
473        }
474
475        // Optional sanity/logging: warn if type differs from recorded name
476        let expected = std::any::type_name::<D>();
477        if metadata.distance_name != expected {
478            eprintln!(
479                "Warning: index recorded distance `{}` but you opened with `{}`",
480                metadata.distance_name, expected
481            );
482        }
483
484        Ok(Self {
485            dim: metadata.dim,
486            num_vectors: metadata.num_vectors,
487            max_degree: metadata.max_degree,
488            distance_name: metadata.distance_name,
489            medoid_id: metadata.medoid_id,
490            vectors_offset: metadata.vectors_offset,
491            adjacency_offset: metadata.adjacency_offset,
492            mmap,
493            dist,
494            _phantom: PhantomData,
495        })
496    }
497}
498
499/// Extra sugar when your distance type implements `Default`.
500impl<T, D> DiskANN<T, D>
501where
502    T: bytemuck::Pod + Copy + Send + Sync + 'static,
503    D: Distance<T> + Default + Send + Sync + Copy + Clone + 'static,
504{
505    /// Build with default params **and** `D::default()` metric.
506    pub fn build_index_default_metric(
507        vectors: &[Vec<T>],
508        file_path: &str,
509    ) -> Result<Self, DiskAnnError> {
510        Self::build_index_default(vectors, D::default(), file_path)
511    }
512
513    /// Open an index using `D::default()` as the distance (matches what you built with).
514    pub fn open_index_default_metric(path: &str) -> Result<Self, DiskAnnError> {
515        Self::open_index_with(path, D::default())
516    }
517}
518
519impl<T, D> DiskANN<T, D>
520where
521    T: bytemuck::Pod + Copy + Send + Sync + 'static,
522    D: Distance<T> + Send + Sync + Copy + Clone + 'static,
523{
524    /// Builds a new index from provided vectors
525    ///
526    /// # Arguments
527    /// * `vectors` - The vectors to index (slice of Vec<T>)
528    /// * `max_degree` - Maximum edges per node (M ~ 24-64+)
529    /// * `build_beam_width` - Construction L (e.g., 128-400)
530    /// * `alpha` - Pruning parameter (1.2–2.0)
531    /// * `passes` - Refinement passes over the graph (>=1)
532    /// * `extra_seeds` - Extra random seeds per node per pass (>=0)
533    /// * `dist` - Any `anndists::Distance<T>`
534    /// * `file_path` - Path of index file
535    pub fn build_index(
536        vectors: &[Vec<T>],
537        max_degree: usize,
538        build_beam_width: usize,
539        alpha: f32,
540        passes: usize,
541        extra_seeds: usize,
542        dist: D,
543        file_path: &str,
544    ) -> Result<Self, DiskAnnError> {
545        let flat = FlatVectors::from_vecs(vectors)?;
546
547        let num_vectors = flat.n;
548        let dim = flat.dim;
549
550        let mut file = OpenOptions::new()
551            .create(true)
552            .write(true)
553            .read(true)
554            .truncate(true)
555            .open(file_path)?;
556
557        // Reserve space for metadata (we'll write it after data)
558        let vectors_offset = 1024 * 1024;
559        assert_eq!(
560            (vectors_offset as usize) % std::mem::align_of::<T>(),
561            0,
562            "vectors_offset must be aligned for T"
563        );
564
565        let elem_sz = std::mem::size_of::<T>() as u64;
566        let total_vector_bytes = (num_vectors as u64) * (dim as u64) * elem_sz;
567
568        // Write vectors contiguous
569        file.seek(SeekFrom::Start(vectors_offset as u64))?;
570        file.write_all(bytemuck::cast_slice::<T, u8>(&flat.data))?;
571
572        // Compute medoid using flat storage
573        let medoid_id = calculate_medoid(&flat, dist);
574
575        // Build graph
576        let adjacency_offset = vectors_offset as u64 + total_vector_bytes;
577        let graph = build_vamana_graph(
578            &flat,
579            max_degree,
580            build_beam_width,
581            alpha,
582            passes,
583            extra_seeds,
584            dist,
585            medoid_id as u32,
586        );
587
588        // Write adjacency lists
589        file.seek(SeekFrom::Start(adjacency_offset))?;
590        for neighbors in &graph {
591            let mut padded = neighbors.clone();
592            padded.resize(max_degree, PAD_U32);
593            let bytes = bytemuck::cast_slice::<u32, u8>(&padded);
594            file.write_all(bytes)?;
595        }
596
597        // Write metadata
598        let metadata = Metadata {
599            dim,
600            num_vectors,
601            max_degree,
602            medoid_id: medoid_id as u32,
603            vectors_offset: vectors_offset as u64,
604            adjacency_offset,
605            elem_size: std::mem::size_of::<T>() as u8,
606            distance_name: std::any::type_name::<D>().to_string(),
607        };
608
609        let md_bytes = bincode::serialize(&metadata)?;
610        file.seek(SeekFrom::Start(0))?;
611        let md_len = md_bytes.len() as u64;
612        file.write_all(&md_len.to_le_bytes())?;
613        file.write_all(&md_bytes)?;
614        file.sync_all()?;
615
616        // Memory map the file
617        let mmap = unsafe { memmap2::Mmap::map(&file)? };
618
619        Ok(Self {
620            dim,
621            num_vectors,
622            max_degree,
623            distance_name: metadata.distance_name,
624            medoid_id: metadata.medoid_id,
625            vectors_offset: metadata.vectors_offset,
626            adjacency_offset: metadata.adjacency_offset,
627            mmap,
628            dist,
629            _phantom: PhantomData,
630        })
631    }
632
633    /// Searches the index for nearest neighbors using a best-first beam search.
634    /// Termination rule: continue while the best frontier can still improve the worst in working set.
635    pub fn search_with_dists(&self, query: &[T], k: usize, beam_width: usize) -> Vec<(u32, f32)> {
636        assert_eq!(
637            query.len(),
638            self.dim,
639            "Query dim {} != index dim {}",
640            query.len(),
641            self.dim
642        );
643
644        let mut visited = HashSet::new();
645        let mut frontier: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new();
646        let mut w: BinaryHeap<Candidate> = BinaryHeap::new();
647
648        let start_dist = self.distance_to(query, self.medoid_id as usize);
649        let start = Candidate {
650            dist: start_dist,
651            id: self.medoid_id,
652        };
653        frontier.push(Reverse(start));
654        w.push(start);
655        visited.insert(self.medoid_id);
656
657        while let Some(Reverse(best)) = frontier.peek().copied() {
658            if w.len() >= beam_width {
659                if let Some(worst) = w.peek() {
660                    if best.dist >= worst.dist {
661                        break;
662                    }
663                }
664            }
665            let Reverse(current) = frontier.pop().unwrap();
666
667            for &nb in self.get_neighbors(current.id) {
668                if nb == PAD_U32 {
669                    continue;
670                }
671                if !visited.insert(nb) {
672                    continue;
673                }
674
675                let d = self.distance_to(query, nb as usize);
676                let cand = Candidate { dist: d, id: nb };
677
678                if w.len() < beam_width {
679                    w.push(cand);
680                    frontier.push(Reverse(cand));
681                } else if d < w.peek().unwrap().dist {
682                    w.pop();
683                    w.push(cand);
684                    frontier.push(Reverse(cand));
685                }
686            }
687        }
688
689        let mut results: Vec<_> = w.into_vec();
690        results.sort_by(|a, b| a.dist.total_cmp(&b.dist));
691        results.truncate(k);
692        results.into_iter().map(|c| (c.id, c.dist)).collect()
693    }
694
695    /// search but only return neighbor ids
696    pub fn search(&self, query: &[T], k: usize, beam_width: usize) -> Vec<u32> {
697        self.search_with_dists(query, k, beam_width)
698            .into_iter()
699            .map(|(id, _dist)| id)
700            .collect()
701    }
702
703    /// Gets the neighbors of a node from the (fixed-degree) adjacency region
704    fn get_neighbors(&self, node_id: u32) -> &[u32] {
705        let offset = self.adjacency_offset + (node_id as u64 * self.max_degree as u64 * 4);
706        let start = offset as usize;
707        let end = start + (self.max_degree * 4);
708        let bytes = &self.mmap[start..end];
709        bytemuck::cast_slice(bytes)
710    }
711
712    /// Computes distance between `query` and vector `idx`
713    fn distance_to(&self, query: &[T], idx: usize) -> f32 {
714        let elem_sz = std::mem::size_of::<T>();
715        let offset = self.vectors_offset + (idx as u64 * self.dim as u64 * elem_sz as u64);
716        let start = offset as usize;
717        let end = start + (self.dim * elem_sz);
718        let bytes = &self.mmap[start..end];
719        let vector: &[T] = bytemuck::cast_slice(bytes);
720        self.dist.eval(query, vector)
721    }
722
723    /// Gets a vector from the index
724    pub fn get_vector(&self, idx: usize) -> Vec<T> {
725        let elem_sz = std::mem::size_of::<T>();
726        let offset = self.vectors_offset + (idx as u64 * self.dim as u64 * elem_sz as u64);
727        let start = offset as usize;
728        let end = start + (self.dim * elem_sz);
729        let bytes = &self.mmap[start..end];
730        let vector: &[T] = bytemuck::cast_slice(bytes);
731        vector.to_vec()
732    }
733}
734
735/// Calculates the medoid using flat contiguous storage.
736fn calculate_medoid<T, D>(vectors: &FlatVectors<T>, dist: D) -> usize
737where
738    T: bytemuck::Pod + Copy + Send + Sync,
739    D: Distance<T> + Copy + Sync,
740{
741    let n = vectors.n;
742    let k = 8.min(n);
743    let mut rng = thread_rng();
744    let pivots: Vec<usize> = (0..k).map(|_| rng.gen_range(0..n)).collect();
745
746    let (best_idx, _best_score) = (0..n)
747        .into_par_iter()
748        .map(|i| {
749            let vi = vectors.row(i);
750            let score: f32 = pivots.iter().map(|&p| dist.eval(vi, vectors.row(p))).sum();
751            (i, score)
752        })
753        .reduce(|| (0usize, f32::MAX), |a, b| if a.1 <= b.1 { a } else { b });
754
755    best_idx
756}
757
/// Collapses duplicate ids in `cands`, keeping the smallest distance seen for each id.
///
/// Afterwards the vector is sorted by ascending id; distances are compared with
/// `f32::total_cmp`.
fn dedup_keep_best_by_id_in_place(cands: &mut Vec<(u32, f32)>) {
    // Order by id, then by distance, so the first entry of each id-run is its best.
    cands.sort_unstable_by(|x, y| match x.0.cmp(&y.0) {
        Ordering::Equal => x.1.total_cmp(&y.1),
        unequal => unequal,
    });
    // `dedup_by_key` keeps the first element of every run of equal keys.
    cands.dedup_by_key(|c| c.0);
}
777
778/// Merge one micro-batch of newly computed outgoing neighbor lists back into the graph.
779/// Mark all nodes in the chunk as affected.
780/// Count reverse-edge insertions implied by the chunk.
781/// Build a CSR-style flat incoming buffer.
782/// Commit chunk outgoing lists.
783/// Re-prune only the affected nodes.
784fn merge_chunk_updates_into_graph_reuse<T, D>(
785    graph: &mut [Vec<u32>],
786    chunk_nodes: &[usize],
787    chunk_pruned: &[Vec<u32>],
788    vectors: &FlatVectors<T>,
789    max_degree: usize,
790    slack_limit: usize,
791    alpha: f32,
792    dist: D,
793    merge: &mut MergeScratch,
794) where
795    T: bytemuck::Pod + Copy + Send + Sync,
796    D: Distance<T> + Copy + Sync,
797{
798    merge.reset();
799
800    // Mark chunk nodes as affected.
801    for &u in chunk_nodes {
802        merge.mark_affected(u);
803    }
804
805    // Count reverse-edge insertions and mark touched destinations.
806    let mut total_incoming = 0usize;
807
808    for (local_idx, &u) in chunk_nodes.iter().enumerate() {
809        for &dst in &chunk_pruned[local_idx] {
810            let dst_usize = dst as usize;
811            if dst_usize == u {
812                continue;
813            }
814
815            merge.mark_affected(dst_usize);
816            merge.incoming_counts[dst_usize] += 1;
817            total_incoming += 1;
818        }
819    }
820
821    // Build CSR offsets only for affected nodes.
822    merge.affected_nodes.sort_unstable();
823
824    let mut running = 0usize;
825    for &u in &merge.affected_nodes {
826        merge.incoming_offsets[u] = running;
827        running += merge.incoming_counts[u];
828        merge.incoming_offsets[u + 1] = running;
829    }
830
831    merge.incoming_flat.resize(total_incoming, PAD_U32);
832
833    // Initialize write cursors.
834    for &u in &merge.affected_nodes {
835        merge.incoming_write[u] = merge.incoming_offsets[u];
836    }
837
838    // Fill CSR incoming buffer.
839    for (local_idx, &u) in chunk_nodes.iter().enumerate() {
840        for &dst in &chunk_pruned[local_idx] {
841            let dst_usize = dst as usize;
842            if dst_usize == u {
843                continue;
844            }
845
846            let pos = merge.incoming_write[dst_usize];
847            merge.incoming_flat[pos] = u as u32;
848            merge.incoming_write[dst_usize] += 1;
849        }
850    }
851
852    // Commit chunk outgoing lists first.
853    for (local_idx, &u) in chunk_nodes.iter().enumerate() {
854        graph[u] = chunk_pruned[local_idx].clone();
855    }
856
857    // Hybrid slack-aware microbatch merge:
858    // keep merged lists under slack, reprune only overflowed ones.
859    let affected = merge.affected_nodes.clone();
860
861    let updated_pairs: Vec<(usize, Vec<u32>)> = affected
862        .into_par_iter()
863        .map(|u| {
864            let start = merge.incoming_offsets[u];
865            let end = merge.incoming_offsets[u + 1];
866
867            let mut ids: Vec<u32> = Vec::with_capacity(graph[u].len() + (end - start));
868
869            // Current adjacency after chunk commit.
870            ids.extend_from_slice(&graph[u]);
871
872            // Reverse insertions from this chunk.
873            if start < end {
874                ids.extend_from_slice(&merge.incoming_flat[start..end]);
875            }
876
877            // Remove self-loops / padding.
878            ids.retain(|&id| id != PAD_U32 && id as usize != u);
879
880            // Deduplicate.
881            ids.sort_unstable();
882            ids.dedup();
883
884            if ids.is_empty() {
885                return (u, Vec::new());
886            }
887
888            // Under slack: keep as-is.
889            if ids.len() <= slack_limit {
890                return (u, ids);
891            }
892
893            // Overflow: score and prune back to max_degree.
894            let mut pool = Vec::<(u32, f32)>::with_capacity(ids.len());
895            for id in ids {
896                let d = dist.eval(vectors.row(u), vectors.row(id as usize));
897                pool.push((id, d));
898            }
899
900            let pruned = prune_neighbors(u, &pool, vectors, max_degree, alpha, dist);
901            (u, pruned)
902        })
903        .collect();
904
905    for (u, neigh) in updated_pairs {
906        graph[u] = neigh;
907    }
908
909    // Cleanup touched metadata.
910    for &u in &merge.affected_nodes {
911        merge.incoming_counts[u] = 0;
912        merge.incoming_offsets[u + 1] = 0;
913    }
914}
915
/// Scratch space for the micro-batch merge, reused across chunks so that no
/// per-chunk `Vec<Vec<u32>>` of incoming reverse edges has to be allocated.
///
/// Reverse edges produced by a chunk are stored CSR-style:
/// - `incoming_counts[u]`: how many reverse edges target `u` in the current chunk;
/// - `incoming_offsets[u]..incoming_offsets[u + 1]`: `u`'s segment of `incoming_flat`;
/// - `incoming_write[u]`: fill cursor into that segment;
/// - `affected_nodes`: exactly the nodes touched by the current chunk, so the merge
///   never scans all `n` nodes.
///
/// `affected_marks[u] == affected_epoch` means `u` is already in `affected_nodes`;
/// bumping the epoch un-marks every node in O(1).
#[derive(Debug)]
struct MergeScratch {
    incoming_counts: Vec<usize>,
    incoming_offsets: Vec<usize>,
    incoming_write: Vec<usize>,
    incoming_flat: Vec<u32>,

    affected_marks: Vec<u32>,
    affected_epoch: u32,
    affected_nodes: Vec<usize>,
}

impl MergeScratch {
    /// Allocates scratch for a graph of `n` nodes.
    fn new(n: usize) -> Self {
        let zeros = |len: usize| vec![0usize; len];
        Self {
            incoming_counts: zeros(n),
            incoming_offsets: zeros(n + 1),
            incoming_write: zeros(n),
            incoming_flat: Vec::new(),
            affected_marks: vec![0u32; n],
            affected_epoch: 1,
            affected_nodes: Vec::new(),
        }
    }

    /// Starts a new chunk: advances the epoch and empties the per-chunk lists.
    #[inline]
    fn reset(&mut self) {
        self.affected_epoch = match self.affected_epoch.checked_add(1) {
            Some(next) => next,
            None => {
                // Epoch wrapped around: stale marks could collide, so wipe them all.
                self.affected_marks.fill(0);
                1
            }
        };
        self.affected_nodes.clear();
        self.incoming_flat.clear();
    }

    /// Records `u` as affected (once per chunk) and zeroes its incoming count.
    #[inline]
    fn mark_affected(&mut self, u: usize) {
        if self.affected_marks[u] == self.affected_epoch {
            return;
        }
        self.affected_marks[u] = self.affected_epoch;
        self.affected_nodes.push(u);
        self.incoming_counts[u] = 0;
    }
}
970
971/// Build a Vamana-like graph using a micro-batched practical DiskANN strategy,
972/// with reusable scratch both for per-thread search state and for chunk merge state.
973fn build_vamana_graph<T, D>(
974    vectors: &FlatVectors<T>,
975    max_degree: usize,
976    build_beam_width: usize,
977    alpha: f32,
978    passes: usize,
979    extra_seeds: usize,
980    dist: D,
981    medoid_id: u32,
982) -> Vec<Vec<u32>>
983where
984    T: bytemuck::Pod + Copy + Send + Sync,
985    D: Distance<T> + Copy + Sync,
986{
987    let n = vectors.n;
988    let mut graph = vec![Vec::<u32>::new(); n];
989    // Bootstrap with a random R-out directed graph.
990    {
991        let mut rng = thread_rng();
992        let target = max_degree.min(n.saturating_sub(1));
993
994        for i in 0..n {
995            let mut s = HashSet::with_capacity(target);
996            while s.len() < target {
997                let nb = rng.gen_range(0..n);
998                if nb != i {
999                    s.insert(nb as u32);
1000                }
1001            }
1002            graph[i] = s.into_iter().collect();
1003        }
1004    }
1005
1006    let passes = passes.max(1);
1007    let mut rng = thread_rng();
1008    let slack_limit = ((GRAPH_SLACK_FACTOR * max_degree as f32).ceil() as usize).max(max_degree);
1009
1010    // Reused across all chunks in all passes.
1011    let mut merge_scratch = MergeScratch::new(n);
1012
1013    for pass_idx in 0..passes {
1014        let pass_alpha = if passes == 1 {
1015            alpha
1016        } else if pass_idx == 0 {
1017            1.0
1018        } else {
1019            alpha
1020        };
1021
1022        let mut order: Vec<usize> = (0..n).collect();
1023        order.shuffle(&mut rng);
1024
1025        for chunk in order.chunks(MICRO_BATCH_CHUNK_SIZE) {
1026            let snapshot = &graph;
1027            // Compute new outgoing lists for this chunk in parallel.
1028            let chunk_results: Vec<(usize, Vec<u32>)> = chunk
1029                .par_iter()
1030                .map_init(
1031                    || IncrementalInsertScratch::new(n, build_beam_width, max_degree, extra_seeds),
1032                    |scratch, &u| {
1033                        let bs = &mut scratch.build;
1034                        bs.candidates.clear();
1035
1036                        // Start from current adjacency.
1037                        for &nb in &snapshot[u] {
1038                            let d = dist.eval(vectors.row(u), vectors.row(nb as usize));
1039                            bs.candidates.push((nb, d));
1040                        }
1041
1042                        // Seed list: medoid + distinct random starts.
1043                        bs.seeds.clear();
1044                        bs.seeds.push(medoid_id as usize);
1045
1046                        let mut local_rng = thread_rng();
1047                        while bs.seeds.len() < 1 + extra_seeds {
1048                            let s = local_rng.gen_range(0..n);
1049                            if !bs.seeds.contains(&s) {
1050                                bs.seeds.push(s);
1051                            }
1052                        }
1053
1054                        let seeds_len = bs.seeds.len();
1055                        for si in 0..seeds_len {
1056                            let start = bs.seeds[si];
1057
1058                            greedy_search_visited_collect(
1059                                vectors.row(u),
1060                                vectors,
1061                                snapshot,
1062                                start,
1063                                build_beam_width,
1064                                dist,
1065                                bs,
1066                            );
1067
1068                            for i in 0..bs.visited_ids.len() {
1069                                bs.candidates.push((bs.visited_ids[i], bs.visited_dists[i]));
1070                            }
1071                        }
1072
1073                        dedup_keep_best_by_id_in_place(&mut bs.candidates);
1074
1075                        let pruned = prune_neighbors(
1076                            u,
1077                            &bs.candidates,
1078                            vectors,
1079                            max_degree,
1080                            pass_alpha,
1081                            dist,
1082                        );
1083
1084                        (u, pruned)
1085                    },
1086                )
1087                .collect();
1088
1089            let mut chunk_nodes = Vec::<usize>::with_capacity(chunk_results.len());
1090            let mut chunk_pruned = Vec::<Vec<u32>>::with_capacity(chunk_results.len());
1091
1092            for (u, pruned) in chunk_results {
1093                chunk_nodes.push(u);
1094                chunk_pruned.push(pruned);
1095            }
1096            // Merge chunk back into graph using reusable CSR-style scratch.
1097            merge_chunk_updates_into_graph_reuse(
1098                &mut graph,
1099                &chunk_nodes,
1100                &chunk_pruned,
1101                vectors,
1102                max_degree,
1103                slack_limit,
1104                pass_alpha,
1105                dist,
1106                &mut merge_scratch,
1107            );
1108        }
1109    }
1110
1111    // Final cleanup: enforce bounded degree and deduplication.
1112    graph
1113        .into_par_iter()
1114        .enumerate()
1115        .map(|(u, neigh)| {
1116            if neigh.len() <= max_degree {
1117                return neigh;
1118            }
1119
1120            let mut ids = neigh;
1121            ids.sort_unstable();
1122            ids.dedup();
1123
1124            let pool: Vec<(u32, f32)> = ids
1125                .into_iter()
1126                .filter(|&id| id as usize != u)
1127                .map(|id| (id, dist.eval(vectors.row(u), vectors.row(id as usize))))
1128                .collect();
1129
1130            prune_neighbors(u, &pool, vectors, max_degree, alpha, dist)
1131        })
1132        .collect()
1133}
1134
1135/// Build-time greedy search:
1136/// - dense visited marks instead of HashMap/HashSet
1137/// - visited_ids + visited_dists instead of recomputing distances later
1138/// - ordered beams instead of BinaryHeap
1139/// Output is written into `scratch.visited_ids` and `scratch.visited_dists`.
1140fn greedy_search_visited_collect<T, D>(
1141    query: &[T],
1142    vectors: &FlatVectors<T>,
1143    graph: &[Vec<u32>],
1144    start_id: usize,
1145    beam_width: usize,
1146    dist: D,
1147    scratch: &mut BuildScratch,
1148) where
1149    T: bytemuck::Pod + Copy + Send + Sync,
1150    D: Distance<T> + Copy,
1151{
1152    scratch.reset_search();
1153
1154    let start_dist = dist.eval(query, vectors.row(start_id));
1155    let start = Candidate {
1156        dist: start_dist,
1157        id: start_id as u32,
1158    };
1159
1160    scratch.frontier.insert_unbounded(start);
1161    scratch.work.insert_capped(start, beam_width);
1162    scratch.mark_with_dist(start_id, start_dist);
1163
1164    while !scratch.frontier.is_empty() {
1165        let best = scratch.frontier.best().unwrap();
1166        if scratch.work.len() >= beam_width {
1167            if let Some(worst) = scratch.work.worst() {
1168                if best.dist >= worst.dist {
1169                    break;
1170                }
1171            }
1172        }
1173
1174        let cur = scratch.frontier.pop_best().unwrap();
1175
1176        for &nb in &graph[cur.id as usize] {
1177            let nb_usize = nb as usize;
1178            if scratch.is_marked(nb_usize) {
1179                continue;
1180            }
1181
1182            let d = dist.eval(query, vectors.row(nb_usize));
1183            scratch.mark_with_dist(nb_usize, d);
1184
1185            let cand = Candidate { dist: d, id: nb };
1186
1187            if scratch.work.len() < beam_width {
1188                scratch.work.insert_unbounded(cand);
1189                scratch.frontier.insert_unbounded(cand);
1190            } else if let Some(worst) = scratch.work.worst() {
1191                if d < worst.dist {
1192                    scratch.work.insert_capped(cand, beam_width);
1193                    scratch.frontier.insert_unbounded(cand);
1194                }
1195            }
1196        }
1197    }
1198}
1199
1200/// α-pruning with nearest-neighbor backfill.
1201fn prune_neighbors<T, D>(
1202    node_id: usize,
1203    candidates: &[(u32, f32)],
1204    vectors: &FlatVectors<T>,
1205    max_degree: usize,
1206    alpha: f32,
1207    dist: D,
1208) -> Vec<u32>
1209where
1210    T: bytemuck::Pod + Copy + Send + Sync,
1211    D: Distance<T> + Copy,
1212{
1213    if candidates.is_empty() || max_degree == 0 {
1214        return Vec::new();
1215    }
1216
1217    // Sort by distance from node_id, nearest first.
1218    let mut sorted = candidates.to_vec();
1219    sorted.sort_by(|a, b| a.1.total_cmp(&b.1));
1220
1221    // Remove self and duplicate ids while keeping the nearest occurrence.
1222    let mut uniq = Vec::<(u32, f32)>::with_capacity(sorted.len());
1223    let mut last_id: Option<u32> = None;
1224    for &(cand_id, cand_dist) in &sorted {
1225        if cand_id as usize == node_id {
1226            continue;
1227        }
1228        if last_id == Some(cand_id) {
1229            continue;
1230        }
1231        uniq.push((cand_id, cand_dist));
1232        last_id = Some(cand_id);
1233    }
1234
1235    if uniq.is_empty() {
1236        return Vec::new();
1237    }
1238
1239    let mut pruned = Vec::<u32>::with_capacity(max_degree);
1240
1241    // Phase 1: robust α-pruning
1242    for &(cand_id, cand_dist_to_node) in &uniq {
1243        let mut occluded = false;
1244
1245        for &sel_id in &pruned {
1246            let d_cand_sel = dist.eval(
1247                vectors.row(cand_id as usize),
1248                vectors.row(sel_id as usize),
1249            );
1250
1251            if alpha * d_cand_sel <= cand_dist_to_node {
1252                occluded = true;
1253                break;
1254            }
1255        }
1256
1257        if !occluded {
1258            pruned.push(cand_id);
1259            if pruned.len() >= max_degree {
1260                return pruned;
1261            }
1262        }
1263    }
1264
1265    // Phase 2: backfill nearest remaining candidates
1266    if pruned.len() < max_degree {
1267        for &(cand_id, _) in &uniq {
1268            if pruned.contains(&cand_id) {
1269                continue;
1270            }
1271            pruned.push(cand_id);
1272            if pruned.len() >= max_degree {
1273                break;
1274            }
1275        }
1276    }
1277
1278    pruned
1279}
1280
#[cfg(test)]
mod tests {
    use super::*;
    use anndists::dist::{DistCosine, DistL2};
    use rand::Rng;
    use std::fs;

    /// Index file owned by a single test, placed in the OS temp directory and deleted on
    /// drop.
    ///
    /// Removing the file in `Drop` means it is cleaned up even when an assertion panics,
    /// which a trailing `fs::remove_file` call cannot guarantee. Declare the guard *before*
    /// the index: locals drop in reverse declaration order, so the index (and any memory
    /// map it holds on the file) is released before the file is removed.
    struct TempIndexFile {
        path: String,
    }

    impl TempIndexFile {
        fn new(name: &str) -> Self {
            // Prefix with the process id so concurrent test processes never share a file.
            let path = std::env::temp_dir()
                .join(format!("rust_diskann_{}_{}", std::process::id(), name))
                .to_str()
                .expect("temp dir path must be valid UTF-8")
                .to_owned();
            // Clear leftovers from an earlier aborted run.
            let _ = fs::remove_file(&path);
            Self { path }
        }

        fn path(&self) -> &str {
            &self.path
        }
    }

    impl Drop for TempIndexFile {
        fn drop(&mut self) {
            let _ = fs::remove_file(&self.path);
        }
    }

    fn euclid(a: &[f32], b: &[f32]) -> f32 {
        a.iter()
            .zip(b)
            .map(|(x, y)| (x - y) * (x - y))
            .sum::<f32>()
            .sqrt()
    }

    #[test]
    fn test_small_index_l2() {
        let file = TempIndexFile::new("test_small_l2.db");
        let path = file.path();

        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
            vec![0.5, 0.5],
        ];

        let index = DiskANN::<f32, DistL2>::build_index_default(&vectors, DistL2, path).unwrap();

        let q = vec![0.1, 0.1];
        let nns = index.search(&q, 3, 8);
        assert_eq!(nns.len(), 3);

        let v = index.get_vector(nns[0] as usize);
        assert!(euclid(&q, &v) < 1.0);
    }

    #[test]
    fn test_cosine() {
        let file = TempIndexFile::new("test_cosine.db");
        let path = file.path();

        let vectors = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
            vec![1.0, 1.0, 0.0],
            vec![1.0, 0.0, 1.0],
        ];

        let index =
            DiskANN::<f32, DistCosine>::build_index_default(&vectors, DistCosine, path).unwrap();

        let q = vec![2.0, 0.0, 0.0];
        let nns = index.search(&q, 2, 8);
        assert_eq!(nns.len(), 2);

        let v = index.get_vector(nns[0] as usize);
        let dot = v.iter().zip(&q).map(|(a, b)| a * b).sum::<f32>();
        let n1 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        let n2 = q.iter().map(|x| x * x).sum::<f32>().sqrt();
        let cos = dot / (n1 * n2);
        assert!(cos > 0.7);
    }

    #[test]
    fn test_persistence_and_open() {
        let file = TempIndexFile::new("test_persist.db");
        let path = file.path();

        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
        ];

        {
            let _idx =
                DiskANN::<f32, DistL2>::build_index_default(&vectors, DistL2, path).unwrap();
        }

        let idx2 = DiskANN::<f32, DistL2>::open_index_default_metric(path).unwrap();
        assert_eq!(idx2.num_vectors, 4);
        assert_eq!(idx2.dim, 2);

        let q = vec![0.9, 0.9];
        let res = idx2.search(&q, 2, 8);
        assert_eq!(res[0], 3);
    }

    #[test]
    fn test_grid_connectivity() {
        let file = TempIndexFile::new("test_grid.db");
        let path = file.path();

        let mut vectors = Vec::new();
        for i in 0..5 {
            for j in 0..5 {
                vectors.push(vec![i as f32, j as f32]);
            }
        }

        let index = DiskANN::<f32, DistL2>::build_index_with_params(
            &vectors,
            DistL2,
            path,
            DiskAnnParams {
                max_degree: 4,
                build_beam_width: 64,
                alpha: 1.5,
                passes: DISKANN_DEFAULT_PASSES,
                extra_seeds: DISKANN_DEFAULT_EXTRA_SEEDS,
            },
        )
        .unwrap();

        for target in 0..vectors.len() {
            let q = &vectors[target];
            let nns = index.search(q, 10, 32);
            // Fail with a clear message instead of an index-out-of-bounds panic below.
            assert!(!nns.is_empty(), "search returned no results for target {target}");
            if !nns.contains(&(target as u32)) {
                let v = index.get_vector(nns[0] as usize);
                assert!(euclid(q, &v) < 2.0);
            }
            for &nb in nns.iter().take(5) {
                let v = index.get_vector(nb as usize);
                assert!(euclid(q, &v) < 5.0);
            }
        }
    }

    #[test]
    fn test_medium_random() {
        let file = TempIndexFile::new("test_medium.db");
        let path = file.path();

        let n = 200usize;
        let d = 32usize;
        let mut rng = rand::thread_rng();
        let vectors: Vec<Vec<f32>> = (0..n)
            .map(|_| (0..d).map(|_| rng.r#gen::<f32>()).collect())
            .collect();

        let index = DiskANN::<f32, DistL2>::build_index_with_params(
            &vectors,
            DistL2,
            path,
            DiskAnnParams {
                max_degree: 32,
                build_beam_width: 128,
                alpha: 1.2,
                passes: DISKANN_DEFAULT_PASSES,
                extra_seeds: DISKANN_DEFAULT_EXTRA_SEEDS,
            },
        )
        .unwrap();

        let q: Vec<f32> = (0..d).map(|_| rng.r#gen::<f32>()).collect();
        let res = index.search(&q, 10, 64);
        assert_eq!(res.len(), 10);

        let dists: Vec<f32> = res
            .iter()
            .map(|&id| {
                let v = index.get_vector(id as usize);
                euclid(&q, &v)
            })
            .collect();
        let mut sorted = dists.clone();
        sorted.sort_by(|a, b| a.total_cmp(b));
        assert_eq!(dists, sorted);
    }
}