Skip to main content

rust_diskann/
lib.rs

1//! # DiskAnn (generic over `anndists::Distance<T>`)
2//!
3//! An on-disk DiskANN library that:
4//! - Builds a Vamana-style graph (greedy + α-pruning) in memory
5//! - Writes vectors + fixed-degree adjacency to a single file
6//! - Memory-maps the file for low-overhead reads
7//! - Is **generic over any Distance<T>** from `anndists` (e.g. L2 on `f32`, Cosine on `f32`,
8//!   Hamming on `u64`, …)
9//!
10//! ## Example (f32 + L2)
11//! ```no_run
12//! use anndists::dist::DistL2;
13//! use rust_diskann::{DiskANN, DiskAnnParams};
14//!
15//! let vectors: Vec<Vec<f32>> = vec![vec![0.0; 128]; 1000];
16//! let index = DiskANN::<f32, DistL2>::build_index_default(&vectors, DistL2, "index.db").unwrap();
17//!
18//! let q = vec![0.0; 128];
19//! let nns = index.search(&q, 10, 64);
20//! ```
21//!
22//! ## Example (u64 + Hamming)
23//! ```no_run
24//! use anndists::dist::DistHamming;
25//! use rust_diskann::{DiskANN, DiskAnnParams};
26//! let index: Vec<Vec<u64>> = vec![vec![0u64; 128]; 1000];
27//! let idx = DiskANN::<u64, DistHamming>::build_index_default(&index, DistHamming, "mh.db").unwrap();
28//! let q = vec![0u64; 128];
29//! let _ = idx.search(&q, 10, 64);
30//! ```
31//!
32//! ## File Layout
33//! [ metadata_len:u64 ][ metadata (bincode) ][ padding up to vectors_offset ]
34//! [ vectors (num * dim * T) ][ adjacency (num * max_degree * u32) ]
35//!
//! `vectors_offset` is currently fixed at 1 MiB, which leaves room for the metadata header in front of the vectors.
37
38use anndists::prelude::Distance;
39use memmap2::Mmap;
40use rand::{prelude::*, thread_rng};
41use rayon::prelude::*;
42use serde::{Deserialize, Serialize};
43use std::cmp::{Ordering, Reverse};
44use std::collections::{BinaryHeap, HashSet};
45use std::fs::OpenOptions;
46use std::io::{Read, Seek, SeekFrom, Write};
47use std::marker::PhantomData;
48use thiserror::Error;
49
/// Sentinel stored in unused adjacency slots.
///
/// `u32::MAX` is used (rather than `0`) so that padding can never be mistaken for node 0.
const PAD_U32: u32 = u32::MAX;

/// Default maximum out-degree (`R`) for in-memory DiskANN builds.
pub const DISKANN_DEFAULT_MAX_DEGREE: usize = 64;
/// Default construction beam width (`L`).
pub const DISKANN_DEFAULT_BUILD_BEAM: usize = 128;
/// Default pruning parameter α.
pub const DISKANN_DEFAULT_ALPHA: f32 = 1.2;
/// Default number of refinement passes during graph build.
pub const DISKANN_DEFAULT_PASSES: usize = 1;
/// Default number of extra random seeds per node per pass during graph build.
pub const DISKANN_DEFAULT_EXTRA_SEEDS: usize = 1;

/// Slack multiplier applied to `max_degree` before reverse-neighbor lists are re-pruned.
///
/// Legacy C++ DiskANN lets reverse lists grow to roughly `GRAPH_SLACK_FACTOR * R` before
/// triggering a prune, instead of pruning as soon as they reach `R`.
const GRAPH_SLACK_FACTOR: f32 = 1.3;

/// Number of nodes processed together in one micro-batch during graph build.
///
/// Smaller values are closer to true incremental insertion and usually track practical
/// Vamana behavior more faithfully, but they build more slowly. Larger values build faster
/// but behave more like a batched graph rebuild.
///
/// Recommended starting values:
/// - 128 for better quality
/// - 256 for a balanced tradeoff
/// - 512 for faster builds
const MICRO_BATCH_CHUNK_SIZE: usize = 256;

/// Optional bundle of build knobs.
///
/// Start from [`DiskAnnParams::default`] and override only the fields you need.
#[derive(Clone, Copy, Debug)]
pub struct DiskAnnParams {
    /// Maximum number of edges per node (`R`).
    pub max_degree: usize,
    /// Beam width used by the greedy search during construction (`L`).
    pub build_beam_width: usize,
    /// Pruning parameter α (see `DiskANN::build_index`; typically 1.2–2.0).
    pub alpha: f32,
    /// Number of refinement passes over the graph (>=1).
    pub passes: usize,
    /// Extra random seeds per node during each pass (>=0).
    pub extra_seeds: usize,
}

impl Default for DiskAnnParams {
    /// Every field is taken from the matching `DISKANN_DEFAULT_*` constant.
    fn default() -> Self {
        let max_degree = DISKANN_DEFAULT_MAX_DEGREE;
        let build_beam_width = DISKANN_DEFAULT_BUILD_BEAM;
        let alpha = DISKANN_DEFAULT_ALPHA;
        let passes = DISKANN_DEFAULT_PASSES;
        let extra_seeds = DISKANN_DEFAULT_EXTRA_SEEDS;
        Self {
            max_degree,
            build_beam_width,
            alpha,
            passes,
            extra_seeds,
        }
    }
}
108
109/// Custom error type for DiskAnn operations
110#[derive(Debug, Error)]
111pub enum DiskAnnError {
112    /// Represents I/O errors during file operations
113    #[error("I/O error: {0}")]
114    Io(#[from] std::io::Error),
115
116    /// Represents serialization/deserialization errors
117    #[error("Serialization error: {0}")]
118    Bincode(#[from] bincode::Error),
119
120    /// Represents index-specific errors
121    #[error("Index error: {0}")]
122    IndexError(String),
123}
124
125/// Internal metadata structure stored in the index file
126#[derive(Serialize, Deserialize, Debug)]
127struct Metadata {
128    dim: usize,
129    num_vectors: usize,
130    max_degree: usize,
131    medoid_id: u32,
132    vectors_offset: u64,
133    adjacency_offset: u64,
134    elem_size: u8,
135    distance_name: String,
136}
137
/// Entry of the search/frontier queues: a node id paired with its distance to the query.
///
/// The ordering is total. Candidates compare first by `dist` using [`f32::total_cmp`], so NaN
/// never breaks heap invariants, and then by `id` so that ties are deterministic. Equality
/// compares the exact bit pattern of `dist`, which agrees with `total_cmp` returning `Equal`.
#[derive(Clone, Copy, Debug)]
struct Candidate {
    dist: f32,
    id: u32,
}

impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        self.id == other.id && self.dist.to_bits() == other.dist.to_bits()
    }
}

impl Eq for Candidate {}

impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> Ordering {
        self.dist
            .total_cmp(&other.dist)
            .then_with(|| self.id.cmp(&other.id))
    }
}

impl PartialOrd for Candidate {
    // Canonical form: the total order lives in `Ord`, and `PartialOrd` just wraps it.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
164
165/// Flat contiguous matrix used during build to improve cache locality.
166///
167/// Rows are stored consecutively in `data`, row-major.
168#[derive(Clone, Debug)]
169struct FlatVectors<T> {
170    data: Vec<T>,
171    dim: usize,
172    n: usize,
173}
174
175impl<T: Copy> FlatVectors<T> {
176    fn from_vecs(vectors: &[Vec<T>]) -> Result<Self, DiskAnnError> {
177        if vectors.is_empty() {
178            return Err(DiskAnnError::IndexError("No vectors provided".to_string()));
179        }
180        let dim = vectors[0].len();
181        for (i, v) in vectors.iter().enumerate() {
182            if v.len() != dim {
183                return Err(DiskAnnError::IndexError(format!(
184                    "Vector {} has dimension {} but expected {}",
185                    i,
186                    v.len(),
187                    dim
188                )));
189            }
190        }
191
192        let n = vectors.len();
193        let mut data = Vec::with_capacity(n * dim);
194        for v in vectors {
195            data.extend_from_slice(v);
196        }
197
198        Ok(Self { data, dim, n })
199    }
200
201    #[inline]
202    fn row(&self, idx: usize) -> &[T] {
203        let start = idx * self.dim;
204        let end = start + self.dim;
205        &self.data[start..end]
206    }
207}
208
209/// Small ordered beam structure used only during build-time greedy search.
210///
211/// It keeps elements in **descending** distance order:
212/// - index 0 is the worst element
213/// - last element is the best element
214///
215/// This makes:
216/// - `best()` cheap via `last()`
217/// - `worst()` cheap via `first()`
218/// - capped beam maintenance simple
219#[derive(Default, Debug)]
220struct OrderedBeam {
221    items: Vec<Candidate>,
222}
223
224impl OrderedBeam {
225    #[inline]
226    fn clear(&mut self) {
227        self.items.clear();
228    }
229
230    #[inline]
231    fn len(&self) -> usize {
232        self.items.len()
233    }
234
235    #[inline]
236    fn is_empty(&self) -> bool {
237        self.items.is_empty()
238    }
239
240    #[inline]
241    fn best(&self) -> Option<Candidate> {
242        self.items.last().copied()
243    }
244
245    #[inline]
246    fn worst(&self) -> Option<Candidate> {
247        self.items.first().copied()
248    }
249
250    #[inline]
251    fn pop_best(&mut self) -> Option<Candidate> {
252        self.items.pop()
253    }
254
255    #[inline]
256    fn reserve(&mut self, cap: usize) {
257        if self.items.capacity() < cap {
258            self.items.reserve(cap - self.items.capacity());
259        }
260    }
261
262    #[inline]
263    fn insert_unbounded(&mut self, cand: Candidate) {
264        let pos = self.items.partition_point(|x| {
265            x.dist > cand.dist || (x.dist.to_bits() == cand.dist.to_bits() && x.id > cand.id)
266        });
267        self.items.insert(pos, cand);
268    }
269
270    #[inline]
271    fn insert_capped(&mut self, cand: Candidate, cap: usize) {
272        if cap == 0 {
273            return;
274        }
275
276        if self.items.len() < cap {
277            self.insert_unbounded(cand);
278            return;
279        }
280
281        // Since items[0] is the worst, only insert if the new candidate is better.
282        let worst = self.items[0];
283        if cand.dist >= worst.dist {
284            return;
285        }
286
287        self.insert_unbounded(cand);
288
289        if self.items.len() > cap {
290            self.items.remove(0);
291        }
292    }
293}
294
295/// Reusable scratch buffers for build-time greedy search.
296/// One instance is created per Rayon worker via `map_init`, so allocations are reused
297/// across many nodes in the build.
298#[derive(Debug)]
299struct BuildScratch {
300    marks: Vec<u32>,
301    epoch: u32,
302
303    visited_ids: Vec<u32>,
304    visited_dists: Vec<f32>,
305
306    frontier: OrderedBeam,
307    work: OrderedBeam,
308
309    seeds: Vec<usize>,
310    candidates: Vec<(u32, f32)>,
311}
312
313impl BuildScratch {
314    fn new(n: usize, beam_width: usize, max_degree: usize, extra_seeds: usize) -> Self {
315        Self {
316            marks: vec![0u32; n],
317            epoch: 1,
318            visited_ids: Vec::with_capacity(beam_width * 4),
319            visited_dists: Vec::with_capacity(beam_width * 4),
320            frontier: {
321                let mut b = OrderedBeam::default();
322                b.reserve(beam_width * 2);
323                b
324            },
325            work: {
326                let mut b = OrderedBeam::default();
327                b.reserve(beam_width * 2);
328                b
329            },
330            seeds: Vec::with_capacity(1 + extra_seeds),
331            candidates: Vec::with_capacity(beam_width * (4 + extra_seeds) + max_degree * 2),
332        }
333    }
334
335    #[inline]
336    fn reset_search(&mut self) {
337        self.epoch = self.epoch.wrapping_add(1);
338        if self.epoch == 0 {
339            self.marks.fill(0);
340            self.epoch = 1;
341        }
342        self.visited_ids.clear();
343        self.visited_dists.clear();
344        self.frontier.clear();
345        self.work.clear();
346    }
347
348    #[inline]
349    fn is_marked(&self, idx: usize) -> bool {
350        self.marks[idx] == self.epoch
351    }
352
353    #[inline]
354    fn mark_with_dist(&mut self, idx: usize, dist: f32) {
355        self.marks[idx] = self.epoch;
356        self.visited_ids.push(idx as u32);
357        self.visited_dists.push(dist);
358    }
359}
360
361#[derive(Debug)]
362struct IncrementalInsertScratch {
363    build: BuildScratch,
364}
365
366impl IncrementalInsertScratch {
367    fn new(n: usize, beam_width: usize, max_degree: usize, extra_seeds: usize) -> Self {
368        Self {
369            build: BuildScratch::new(n, beam_width, max_degree, extra_seeds),
370        }
371    }
372}
373
374/// Main struct representing a DiskANN index (generic over vector element `T` and distance `D`)
375pub struct DiskANN<T, D>
376where
377    T: bytemuck::Pod + Copy + Send + Sync + 'static,
378    D: Distance<T> + Send + Sync + Copy + Clone + 'static,
379{
380    /// Dimensionality of vectors in the index
381    pub dim: usize,
382    /// Number of vectors in the index
383    pub num_vectors: usize,
384    /// Maximum number of edges per node
385    pub max_degree: usize,
386    /// Informational: type name of the distance (from metadata)
387    pub distance_name: String,
388
389    /// ID of the medoid (used as entry point)
390    medoid_id: u32,
391    // Offsets
392    vectors_offset: u64,
393    adjacency_offset: u64,
394
395    /// Memory-mapped file
396    mmap: Mmap,
397
398    /// The distance strategy
399    dist: D,
400
401    /// keep `T` in the type so the compiler knows about it
402    _phantom: PhantomData<T>,
403}
404
405// constructors
406
407impl<T, D> DiskANN<T, D>
408where
409    T: bytemuck::Pod + Copy + Send + Sync + 'static,
410    D: Distance<T> + Send + Sync + Copy + Clone + 'static,
411{
412    /// Build with default parameters: (M=64, L=128, alpha=1.2, passes=2, extra_seeds=2).
413    pub fn build_index_default(
414        vectors: &[Vec<T>],
415        dist: D,
416        file_path: &str,
417    ) -> Result<Self, DiskAnnError> {
418        Self::build_index(
419            vectors,
420            DISKANN_DEFAULT_MAX_DEGREE,
421            DISKANN_DEFAULT_BUILD_BEAM,
422            DISKANN_DEFAULT_ALPHA,
423            DISKANN_DEFAULT_PASSES,
424            DISKANN_DEFAULT_EXTRA_SEEDS,
425            dist,
426            file_path,
427        )
428    }
429
430    /// Build with a `DiskAnnParams` bundle.
431    pub fn build_index_with_params(
432        vectors: &[Vec<T>],
433        dist: D,
434        file_path: &str,
435        p: DiskAnnParams,
436    ) -> Result<Self, DiskAnnError> {
437        Self::build_index(
438            vectors,
439            p.max_degree,
440            p.build_beam_width,
441            p.alpha,
442            p.passes,
443            p.extra_seeds,
444            dist,
445            file_path,
446        )
447    }
448
449    /// Opens an existing index file, supplying the distance strategy explicitly.
450    pub fn open_index_with(path: &str, dist: D) -> Result<Self, DiskAnnError> {
451        let mut file = OpenOptions::new().read(true).write(false).open(path)?;
452
453        // Read metadata length
454        let mut buf8 = [0u8; 8];
455        file.seek(SeekFrom::Start(0))?;
456        file.read_exact(&mut buf8)?;
457        let md_len = u64::from_le_bytes(buf8);
458
459        // Read metadata
460        let mut md_bytes = vec![0u8; md_len as usize];
461        file.read_exact(&mut md_bytes)?;
462        let metadata: Metadata = bincode::deserialize(&md_bytes)?;
463
464        let mmap = unsafe { memmap2::Mmap::map(&file)? };
465
466        // Validate element size vs T
467        let want = std::mem::size_of::<T>() as u8;
468        if metadata.elem_size != want {
469            return Err(DiskAnnError::IndexError(format!(
470                "element size mismatch: file has {}B, T is {}B",
471                metadata.elem_size, want
472            )));
473        }
474
475        // Optional sanity/logging: warn if type differs from recorded name
476        let expected = std::any::type_name::<D>();
477        if metadata.distance_name != expected {
478            eprintln!(
479                "Warning: index recorded distance `{}` but you opened with `{}`",
480                metadata.distance_name, expected
481            );
482        }
483
484        Ok(Self {
485            dim: metadata.dim,
486            num_vectors: metadata.num_vectors,
487            max_degree: metadata.max_degree,
488            distance_name: metadata.distance_name,
489            medoid_id: metadata.medoid_id,
490            vectors_offset: metadata.vectors_offset,
491            adjacency_offset: metadata.adjacency_offset,
492            mmap,
493            dist,
494            _phantom: PhantomData,
495        })
496    }
497}
498
499/// Extra sugar when your distance type implements `Default`.
500impl<T, D> DiskANN<T, D>
501where
502    T: bytemuck::Pod + Copy + Send + Sync + 'static,
503    D: Distance<T> + Default + Send + Sync + Copy + Clone + 'static,
504{
505    /// Build with default params **and** `D::default()` metric.
506    pub fn build_index_default_metric(
507        vectors: &[Vec<T>],
508        file_path: &str,
509    ) -> Result<Self, DiskAnnError> {
510        Self::build_index_default(vectors, D::default(), file_path)
511    }
512
513    /// Open an index using `D::default()` as the distance (matches what you built with).
514    pub fn open_index_default_metric(path: &str) -> Result<Self, DiskAnnError> {
515        Self::open_index_with(path, D::default())
516    }
517}
518
519impl<T, D> DiskANN<T, D>
520where
521    T: bytemuck::Pod + Copy + Send + Sync + 'static,
522    D: Distance<T> + Send + Sync + Copy + Clone + 'static,
523{
524    /// Builds a new index from provided vectors
525    ///
526    /// # Arguments
527    /// * `vectors` - The vectors to index (slice of Vec<T>)
528    /// * `max_degree` - Maximum edges per node (M ~ 24-64+)
529    /// * `build_beam_width` - Construction L (e.g., 128-400)
530    /// * `alpha` - Pruning parameter (1.2–2.0)
531    /// * `passes` - Refinement passes over the graph (>=1)
532    /// * `extra_seeds` - Extra random seeds per node per pass (>=0)
533    /// * `dist` - Any `anndists::Distance<T>`
534    /// * `file_path` - Path of index file
535    pub fn build_index(
536        vectors: &[Vec<T>],
537        max_degree: usize,
538        build_beam_width: usize,
539        alpha: f32,
540        passes: usize,
541        extra_seeds: usize,
542        dist: D,
543        file_path: &str,
544    ) -> Result<Self, DiskAnnError> {
545        let flat = FlatVectors::from_vecs(vectors)?;
546
547        let num_vectors = flat.n;
548        let dim = flat.dim;
549
550        let mut file = OpenOptions::new()
551            .create(true)
552            .write(true)
553            .read(true)
554            .truncate(true)
555            .open(file_path)?;
556
557        // Reserve space for metadata (we'll write it after data)
558        let vectors_offset = 1024 * 1024;
559        assert_eq!(
560            (vectors_offset as usize) % std::mem::align_of::<T>(),
561            0,
562            "vectors_offset must be aligned for T"
563        );
564
565        let elem_sz = std::mem::size_of::<T>() as u64;
566        let total_vector_bytes = (num_vectors as u64) * (dim as u64) * elem_sz;
567
568        // Write vectors contiguous
569        file.seek(SeekFrom::Start(vectors_offset as u64))?;
570        file.write_all(bytemuck::cast_slice::<T, u8>(&flat.data))?;
571
572        // Compute medoid using flat storage
573        let medoid_id = calculate_medoid(&flat, dist);
574
575        // Build graph
576        let adjacency_offset = vectors_offset as u64 + total_vector_bytes;
577        let graph = build_vamana_graph(
578            &flat,
579            max_degree,
580            build_beam_width,
581            alpha,
582            passes,
583            extra_seeds,
584            dist,
585            medoid_id as u32,
586        );
587
588        // Write adjacency lists
589        file.seek(SeekFrom::Start(adjacency_offset))?;
590        for neighbors in &graph {
591            let mut padded = neighbors.clone();
592            padded.resize(max_degree, PAD_U32);
593            let bytes = bytemuck::cast_slice::<u32, u8>(&padded);
594            file.write_all(bytes)?;
595        }
596
597        // Write metadata
598        let metadata = Metadata {
599            dim,
600            num_vectors,
601            max_degree,
602            medoid_id: medoid_id as u32,
603            vectors_offset: vectors_offset as u64,
604            adjacency_offset,
605            elem_size: std::mem::size_of::<T>() as u8,
606            distance_name: std::any::type_name::<D>().to_string(),
607        };
608
609        let md_bytes = bincode::serialize(&metadata)?;
610        file.seek(SeekFrom::Start(0))?;
611        let md_len = md_bytes.len() as u64;
612        file.write_all(&md_len.to_le_bytes())?;
613        file.write_all(&md_bytes)?;
614        file.sync_all()?;
615
616        // Memory map the file
617        let mmap = unsafe { memmap2::Mmap::map(&file)? };
618
619        Ok(Self {
620            dim,
621            num_vectors,
622            max_degree,
623            distance_name: metadata.distance_name,
624            medoid_id: metadata.medoid_id,
625            vectors_offset: metadata.vectors_offset,
626            adjacency_offset: metadata.adjacency_offset,
627            mmap,
628            dist,
629            _phantom: PhantomData,
630        })
631    }
632
633    /// Searches the index for nearest neighbors using a best-first beam search.
634    /// Termination rule: continue while the best frontier can still improve the worst in working set.
635    pub fn search_with_dists(&self, query: &[T], k: usize, beam_width: usize) -> Vec<(u32, f32)> {
636        assert_eq!(
637            query.len(),
638            self.dim,
639            "Query dim {} != index dim {}",
640            query.len(),
641            self.dim
642        );
643
644        let mut visited = HashSet::new();
645        let mut frontier: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new();
646        let mut w: BinaryHeap<Candidate> = BinaryHeap::new();
647
648        let start_dist = self.distance_to(query, self.medoid_id as usize);
649        let start = Candidate {
650            dist: start_dist,
651            id: self.medoid_id,
652        };
653        frontier.push(Reverse(start));
654        w.push(start);
655        visited.insert(self.medoid_id);
656
657        while let Some(Reverse(best)) = frontier.peek().copied() {
658            if w.len() >= beam_width {
659                if let Some(worst) = w.peek() {
660                    if best.dist >= worst.dist {
661                        break;
662                    }
663                }
664            }
665            let Reverse(current) = frontier.pop().unwrap();
666
667            for &nb in self.get_neighbors(current.id) {
668                if nb == PAD_U32 {
669                    continue;
670                }
671                if !visited.insert(nb) {
672                    continue;
673                }
674
675                let d = self.distance_to(query, nb as usize);
676                let cand = Candidate { dist: d, id: nb };
677
678                if w.len() < beam_width {
679                    w.push(cand);
680                    frontier.push(Reverse(cand));
681                } else if d < w.peek().unwrap().dist {
682                    w.pop();
683                    w.push(cand);
684                    frontier.push(Reverse(cand));
685                }
686            }
687        }
688
689        let mut results: Vec<_> = w.into_vec();
690        results.sort_by(|a, b| a.dist.total_cmp(&b.dist));
691        results.truncate(k);
692        results.into_iter().map(|c| (c.id, c.dist)).collect()
693    }
694
695    /// search but only return neighbor ids
696    pub fn search(&self, query: &[T], k: usize, beam_width: usize) -> Vec<u32> {
697        self.search_with_dists(query, k, beam_width)
698            .into_iter()
699            .map(|(id, _dist)| id)
700            .collect()
701    }
702
703    /// Gets the neighbors of a node from the (fixed-degree) adjacency region
704    fn get_neighbors(&self, node_id: u32) -> &[u32] {
705        let offset = self.adjacency_offset + (node_id as u64 * self.max_degree as u64 * 4);
706        let start = offset as usize;
707        let end = start + (self.max_degree * 4);
708        let bytes = &self.mmap[start..end];
709        bytemuck::cast_slice(bytes)
710    }
711
712    /// Computes distance between `query` and vector `idx`
713    fn distance_to(&self, query: &[T], idx: usize) -> f32 {
714        let elem_sz = std::mem::size_of::<T>();
715        let offset = self.vectors_offset + (idx as u64 * self.dim as u64 * elem_sz as u64);
716        let start = offset as usize;
717        let end = start + (self.dim * elem_sz);
718        let bytes = &self.mmap[start..end];
719        let vector: &[T] = bytemuck::cast_slice(bytes);
720        self.dist.eval(query, vector)
721    }
722
723    /// Gets a vector from the index
724    pub fn get_vector(&self, idx: usize) -> Vec<T> {
725        let elem_sz = std::mem::size_of::<T>();
726        let offset = self.vectors_offset + (idx as u64 * self.dim as u64 * elem_sz as u64);
727        let start = offset as usize;
728        let end = start + (self.dim * elem_sz);
729        let bytes = &self.mmap[start..end];
730        let vector: &[T] = bytemuck::cast_slice(bytes);
731        vector.to_vec()
732    }
733}
734
735/// Calculates the medoid using flat contiguous storage.
736fn calculate_medoid<T, D>(vectors: &FlatVectors<T>, dist: D) -> usize
737where
738    T: bytemuck::Pod + Copy + Send + Sync,
739    D: Distance<T> + Copy + Sync,
740{
741    let n = vectors.n;
742    let k = 8.min(n);
743    let mut rng = thread_rng();
744    let pivots: Vec<usize> = (0..k).map(|_| rng.gen_range(0..n)).collect();
745
746    let (best_idx, _best_score) = (0..n)
747        .into_par_iter()
748        .map(|i| {
749            let vi = vectors.row(i);
750            let score: f32 = pivots.iter().map(|&p| dist.eval(vi, vectors.row(p))).sum();
751            (i, score)
752        })
753        .reduce(|| (0usize, f32::MAX), |a, b| if a.1 <= b.1 { a } else { b });
754
755    best_idx
756}
757
/// Deduplicates `(id, dist)` pairs by id, keeping the smallest distance for each id.
///
/// On return the vector is sorted by ascending id and each id appears once. Distances are
/// compared with `f32::total_cmp`, so the result is deterministic even when NaN is present.
fn dedup_keep_best_by_id_in_place(cands: &mut Vec<(u32, f32)>) {
    // Sort by id, then by distance, so the first entry of each id-run is its best one.
    cands.sort_by(|lhs, rhs| lhs.0.cmp(&rhs.0).then(lhs.1.total_cmp(&rhs.1)));
    // `dedup_by_key` keeps the first element of every run of equal keys.
    cands.dedup_by_key(|entry| entry.0);
}
777
778/// Merge one micro-batch of newly computed outgoing neighbor lists back into the graph.
779/// Mark all nodes in the chunk as affected.
780/// Count reverse-edge insertions implied by the chunk.
781/// Build a CSR-style flat incoming buffer.
782/// Commit chunk outgoing lists.
783/// Re-prune only the affected nodes.
784///
785/// This keeps the practical micro-batched DiskANN behavior while significantly
786/// reducing allocator pressure on large builds.
787fn merge_chunk_updates_into_graph_reuse<T, D>(
788    graph: &mut [Vec<u32>],
789    chunk_nodes: &[usize],
790    chunk_pruned: &[Vec<u32>],
791    vectors: &FlatVectors<T>,
792    max_degree: usize,
793    _slack_limit: usize,
794    alpha: f32,
795    dist: D,
796    merge: &mut MergeScratch,
797) where
798    T: bytemuck::Pod + Copy + Send + Sync,
799    D: Distance<T> + Copy + Sync,
800{
801    merge.reset();
802    // mark chunk nodes as affected.
803    for &u in chunk_nodes {
804        merge.mark_affected(u);
805    }
806
807    // count reverse-edge insertions, and mark touched destinations.
808    let mut total_incoming = 0usize;
809
810    for (local_idx, &u) in chunk_nodes.iter().enumerate() {
811        for &dst in &chunk_pruned[local_idx] {
812            let dst_usize = dst as usize;
813            if dst_usize == u {
814                continue;
815            }
816
817            merge.mark_affected(dst_usize);
818            merge.incoming_counts[dst_usize] += 1;
819            total_incoming += 1;
820        }
821    }
822
823    // build CSR offsets only for affected nodes.
824    merge.affected_nodes.sort_unstable();
825
826    let mut running = 0usize;
827    for &u in &merge.affected_nodes {
828        merge.incoming_offsets[u] = running;
829        running += merge.incoming_counts[u];
830        merge.incoming_offsets[u + 1] = running;
831    }
832
833    merge.incoming_flat.resize(total_incoming, PAD_U32);
834
835    // For affected nodes only, initialize write cursors.
836    for &u in &merge.affected_nodes {
837        merge.incoming_write[u] = merge.incoming_offsets[u];
838    }
839
840    // fill CSR incoming buffer.
841    // incoming_flat[offset[u]..offset[u+1]] stores reverse insertions targeting u.
842    for (local_idx, &u) in chunk_nodes.iter().enumerate() {
843        for &dst in &chunk_pruned[local_idx] {
844            let dst_usize = dst as usize;
845            if dst_usize == u {
846                continue;
847            }
848
849            let pos = merge.incoming_write[dst_usize];
850            merge.incoming_flat[pos] = u as u32;
851            merge.incoming_write[dst_usize] += 1;
852        }
853    }
854
855    // commit the chunk's outgoing lists first.
856    for (local_idx, &u) in chunk_nodes.iter().enumerate() {
857        graph[u] = chunk_pruned[local_idx].clone();
858    }
859    // re-prune only affected nodes.
860    let affected = merge.affected_nodes.clone();
861
862    let updated_pairs: Vec<(usize, Vec<u32>)> = affected
863        .into_par_iter()
864        .map(|u| {
865            let start = merge.incoming_offsets[u];
866            let end = merge.incoming_offsets[u + 1];
867
868            let mut ids: Vec<u32> = Vec::with_capacity(graph[u].len() + (end - start));
869
870            // Current adjacency after chunk commit.
871            ids.extend_from_slice(&graph[u]);
872
873            // Reverse insertions from this chunk.
874            if start < end {
875                ids.extend_from_slice(&merge.incoming_flat[start..end]);
876            }
877
878            // Remove self-loops.
879            ids.retain(|&id| id != PAD_U32 && id as usize != u);
880
881            // Deduplicate ids before scoring.
882            ids.sort_unstable();
883            ids.dedup();
884
885            if ids.is_empty() {
886                return (u, Vec::new());
887            }
888
889            let mut pool = Vec::<(u32, f32)>::with_capacity(ids.len());
890            for id in ids {
891                let d = dist.eval(vectors.row(u), vectors.row(id as usize));
892                pool.push((id, d));
893            }
894
895            let pruned = prune_neighbors(u, &pool, vectors, max_degree, alpha, dist);
896            (u, pruned)
897        })
898        .collect();
899
900    for (u, neigh) in updated_pairs {
901        graph[u] = neigh;
902    }
903    // Optional cleanup for next chunk: clear only touched counts/offset tails.
904    for &u in &merge.affected_nodes {
905        merge.incoming_counts[u] = 0;
906        merge.incoming_offsets[u + 1] = 0;
907    }
908}
909
/// Reusable scratch buffers for the micro-batch merge step.
///
/// Reverse edges produced by a chunk are gathered into one CSR-like flat buffer, instead of
/// rebuilding a `Vec<Vec<u32>>` for every chunk. Only the nodes touched by the chunk are
/// tracked (in `affected_nodes`), so a merge never scans all `n` nodes.
#[derive(Debug)]
struct MergeScratch {
    /// Number of reverse edges targeting each node in the current chunk.
    incoming_counts: Vec<usize>,
    /// CSR offsets: node `u` owns `incoming_flat[incoming_offsets[u]..incoming_offsets[u + 1]]`.
    incoming_offsets: Vec<usize>,
    /// Per-node write cursor used while filling `incoming_flat`.
    incoming_write: Vec<usize>,
    /// Concatenated reverse-edge sources for all affected nodes.
    incoming_flat: Vec<u32>,

    /// `affected_marks[u] == affected_epoch` iff `u` is already in `affected_nodes`.
    affected_marks: Vec<u32>,
    /// Epoch of the current chunk; bumping it forgets every mark in O(1).
    affected_epoch: u32,
    /// Nodes touched by the current chunk, in first-touch order.
    affected_nodes: Vec<usize>,
}

impl MergeScratch {
    /// Allocates scratch state for a graph of `n` nodes (`incoming_offsets` has `n + 1` slots).
    fn new(n: usize) -> Self {
        let zeros = || vec![0usize; n];
        Self {
            incoming_counts: zeros(),
            incoming_offsets: vec![0usize; n + 1],
            incoming_write: zeros(),
            incoming_flat: Vec::new(),
            affected_marks: vec![0u32; n],
            affected_epoch: 1,
            affected_nodes: Vec::new(),
        }
    }

    /// Starts a new chunk: advances the epoch and empties the per-chunk lists.
    #[inline]
    fn reset(&mut self) {
        self.affected_epoch = match self.affected_epoch.checked_add(1) {
            Some(next) => next,
            None => {
                // The epoch wrapped. Clear stale marks and restart at 1 (0 means "never marked").
                self.affected_marks.fill(0);
                1
            }
        };
        self.affected_nodes.clear();
        self.incoming_flat.clear();
    }

    /// Records `u` as affected the first time it is seen in this chunk, zeroing its count.
    #[inline]
    fn mark_affected(&mut self, u: usize) {
        if self.affected_marks[u] == self.affected_epoch {
            return;
        }
        self.affected_marks[u] = self.affected_epoch;
        self.affected_nodes.push(u);
        self.incoming_counts[u] = 0;
    }
}
964
965/// Build a Vamana-like graph using a micro-batched practical DiskANN strategy,
966/// with reusable scratch both for per-thread search state and for chunk merge state.
967fn build_vamana_graph<T, D>(
968    vectors: &FlatVectors<T>,
969    max_degree: usize,
970    build_beam_width: usize,
971    alpha: f32,
972    passes: usize,
973    extra_seeds: usize,
974    dist: D,
975    medoid_id: u32,
976) -> Vec<Vec<u32>>
977where
978    T: bytemuck::Pod + Copy + Send + Sync,
979    D: Distance<T> + Copy + Sync,
980{
981    let n = vectors.n;
982    let mut graph = vec![Vec::<u32>::new(); n];
983    // Bootstrap with a random R-out directed graph.
984    {
985        let mut rng = thread_rng();
986        let target = max_degree.min(n.saturating_sub(1));
987
988        for i in 0..n {
989            let mut s = HashSet::with_capacity(target);
990            while s.len() < target {
991                let nb = rng.gen_range(0..n);
992                if nb != i {
993                    s.insert(nb as u32);
994                }
995            }
996            graph[i] = s.into_iter().collect();
997        }
998    }
999
1000    let passes = passes.max(1);
1001    let mut rng = thread_rng();
1002    let slack_limit = ((GRAPH_SLACK_FACTOR * max_degree as f32).ceil() as usize).max(max_degree);
1003
1004    // Reused across all chunks in all passes.
1005    let mut merge_scratch = MergeScratch::new(n);
1006
1007    for pass_idx in 0..passes {
1008        let pass_alpha = if passes == 1 {
1009            alpha
1010        } else if pass_idx == 0 {
1011            1.0
1012        } else {
1013            alpha
1014        };
1015
1016        let mut order: Vec<usize> = (0..n).collect();
1017        order.shuffle(&mut rng);
1018
1019        for chunk in order.chunks(MICRO_BATCH_CHUNK_SIZE) {
1020            let snapshot = &graph;
1021            // Compute new outgoing lists for this chunk in parallel.
1022            let chunk_results: Vec<(usize, Vec<u32>)> = chunk
1023                .par_iter()
1024                .map_init(
1025                    || IncrementalInsertScratch::new(n, build_beam_width, max_degree, extra_seeds),
1026                    |scratch, &u| {
1027                        let bs = &mut scratch.build;
1028                        bs.candidates.clear();
1029
1030                        // Start from current adjacency.
1031                        for &nb in &snapshot[u] {
1032                            let d = dist.eval(vectors.row(u), vectors.row(nb as usize));
1033                            bs.candidates.push((nb, d));
1034                        }
1035
1036                        // Seed list: medoid + distinct random starts.
1037                        bs.seeds.clear();
1038                        bs.seeds.push(medoid_id as usize);
1039
1040                        let mut local_rng = thread_rng();
1041                        while bs.seeds.len() < 1 + extra_seeds {
1042                            let s = local_rng.gen_range(0..n);
1043                            if !bs.seeds.contains(&s) {
1044                                bs.seeds.push(s);
1045                            }
1046                        }
1047
1048                        let seeds_len = bs.seeds.len();
1049                        for si in 0..seeds_len {
1050                            let start = bs.seeds[si];
1051
1052                            greedy_search_visited_collect(
1053                                vectors.row(u),
1054                                vectors,
1055                                snapshot,
1056                                start,
1057                                build_beam_width,
1058                                dist,
1059                                bs,
1060                            );
1061
1062                            for i in 0..bs.visited_ids.len() {
1063                                bs.candidates.push((bs.visited_ids[i], bs.visited_dists[i]));
1064                            }
1065                        }
1066
1067                        dedup_keep_best_by_id_in_place(&mut bs.candidates);
1068
1069                        let pruned = prune_neighbors(
1070                            u,
1071                            &bs.candidates,
1072                            vectors,
1073                            max_degree,
1074                            pass_alpha,
1075                            dist,
1076                        );
1077
1078                        (u, pruned)
1079                    },
1080                )
1081                .collect();
1082
1083            let mut chunk_nodes = Vec::<usize>::with_capacity(chunk_results.len());
1084            let mut chunk_pruned = Vec::<Vec<u32>>::with_capacity(chunk_results.len());
1085
1086            for (u, pruned) in chunk_results {
1087                chunk_nodes.push(u);
1088                chunk_pruned.push(pruned);
1089            }
1090            // Merge chunk back into graph using reusable CSR-style scratch.
1091            merge_chunk_updates_into_graph_reuse(
1092                &mut graph,
1093                &chunk_nodes,
1094                &chunk_pruned,
1095                vectors,
1096                max_degree,
1097                slack_limit,
1098                pass_alpha,
1099                dist,
1100                &mut merge_scratch,
1101            );
1102        }
1103    }
1104
1105    // Final cleanup: enforce bounded degree and deduplication.
1106    graph
1107        .into_par_iter()
1108        .enumerate()
1109        .map(|(u, neigh)| {
1110            if neigh.len() <= max_degree {
1111                return neigh;
1112            }
1113
1114            let mut ids = neigh;
1115            ids.sort_unstable();
1116            ids.dedup();
1117
1118            let pool: Vec<(u32, f32)> = ids
1119                .into_iter()
1120                .filter(|&id| id as usize != u)
1121                .map(|id| (id, dist.eval(vectors.row(u), vectors.row(id as usize))))
1122                .collect();
1123
1124            prune_neighbors(u, &pool, vectors, max_degree, alpha, dist)
1125        })
1126        .collect()
1127}
1128
1129/// Build-time greedy search:
1130/// - dense visited marks instead of HashMap/HashSet
1131/// - visited_ids + visited_dists instead of recomputing distances later
1132/// - ordered beams instead of BinaryHeap
1133/// Output is written into `scratch.visited_ids` and `scratch.visited_dists`.
1134fn greedy_search_visited_collect<T, D>(
1135    query: &[T],
1136    vectors: &FlatVectors<T>,
1137    graph: &[Vec<u32>],
1138    start_id: usize,
1139    beam_width: usize,
1140    dist: D,
1141    scratch: &mut BuildScratch,
1142) where
1143    T: bytemuck::Pod + Copy + Send + Sync,
1144    D: Distance<T> + Copy,
1145{
1146    scratch.reset_search();
1147
1148    let start_dist = dist.eval(query, vectors.row(start_id));
1149    let start = Candidate {
1150        dist: start_dist,
1151        id: start_id as u32,
1152    };
1153
1154    scratch.frontier.insert_unbounded(start);
1155    scratch.work.insert_capped(start, beam_width);
1156    scratch.mark_with_dist(start_id, start_dist);
1157
1158    while !scratch.frontier.is_empty() {
1159        let best = scratch.frontier.best().unwrap();
1160        if scratch.work.len() >= beam_width {
1161            if let Some(worst) = scratch.work.worst() {
1162                if best.dist >= worst.dist {
1163                    break;
1164                }
1165            }
1166        }
1167
1168        let cur = scratch.frontier.pop_best().unwrap();
1169
1170        for &nb in &graph[cur.id as usize] {
1171            let nb_usize = nb as usize;
1172            if scratch.is_marked(nb_usize) {
1173                continue;
1174            }
1175
1176            let d = dist.eval(query, vectors.row(nb_usize));
1177            scratch.mark_with_dist(nb_usize, d);
1178
1179            let cand = Candidate { dist: d, id: nb };
1180
1181            if scratch.work.len() < beam_width {
1182                scratch.work.insert_unbounded(cand);
1183                scratch.frontier.insert_unbounded(cand);
1184            } else if let Some(worst) = scratch.work.worst() {
1185                if d < worst.dist {
1186                    scratch.work.insert_capped(cand, beam_width);
1187                    scratch.frontier.insert_unbounded(cand);
1188                }
1189            }
1190        }
1191    }
1192}
1193
1194/// α-pruning
1195fn prune_neighbors<T, D>(
1196    node_id: usize,
1197    candidates: &[(u32, f32)],
1198    vectors: &FlatVectors<T>,
1199    max_degree: usize,
1200    alpha: f32,
1201    dist: D,
1202) -> Vec<u32>
1203where
1204    T: bytemuck::Pod + Copy + Send + Sync,
1205    D: Distance<T> + Copy,
1206{
1207    if candidates.is_empty() || max_degree == 0 {
1208        return Vec::new();
1209    }
1210
1211    // Sort by distance from node_id, nearest first.
1212    let mut sorted = candidates.to_vec();
1213    sorted.sort_by(|a, b| a.1.total_cmp(&b.1));
1214
1215    // Remove self and duplicate ids while keeping the best (nearest) occurrence.
1216    let mut uniq = Vec::<(u32, f32)>::with_capacity(sorted.len());
1217    let mut last_id: Option<u32> = None;
1218    for &(cand_id, cand_dist) in &sorted {
1219        if cand_id as usize == node_id {
1220            continue;
1221        }
1222        if last_id == Some(cand_id) {
1223            continue;
1224        }
1225        uniq.push((cand_id, cand_dist));
1226        last_id = Some(cand_id);
1227    }
1228
1229    let mut pruned = Vec::<u32>::with_capacity(max_degree);
1230
1231    // Pure robust pruning: DO NOT backfill rejected candidates.
1232    for &(cand_id, cand_dist_to_node) in &uniq {
1233        let mut occluded = false;
1234
1235        for &sel_id in &pruned {
1236            let d_cand_sel = dist.eval(
1237                vectors.row(cand_id as usize),
1238                vectors.row(sel_id as usize),
1239            );
1240
1241            // If selected neighbor sel_id "covers" cand_id, reject cand_id.
1242            if alpha * d_cand_sel <= cand_dist_to_node {
1243                occluded = true;
1244                break;
1245            }
1246        }
1247
1248        if !occluded {
1249            pruned.push(cand_id);
1250            if pruned.len() >= max_degree {
1251                break;
1252            }
1253        }
1254    }
1255
1256    pruned
1257}
1258
#[cfg(test)]
mod tests {
    use super::*;
    use anndists::dist::{DistCosine, DistL2};
    use rand::Rng;
    use std::fs;

    /// Owns the on-disk path of a test index and deletes the file when dropped.
    ///
    /// Declare the guard *before* the index. Locals drop in reverse order, so the index and
    /// its memory map go away first. The file is then removed even if an assertion panics
    /// midway through a test. Paths live in the system temp dir and include the process id,
    /// so tests neither litter the working directory nor collide across concurrent runs.
    struct TempIndexFile {
        path: String,
    }

    impl TempIndexFile {
        fn new(name: &str) -> Self {
            let path = std::env::temp_dir()
                .join(format!("rust_diskann_{}_{}", std::process::id(), name))
                .to_string_lossy()
                .into_owned();
            let _ = fs::remove_file(&path);
            TempIndexFile { path }
        }

        fn path(&self) -> &str {
            &self.path
        }
    }

    impl Drop for TempIndexFile {
        fn drop(&mut self) {
            let _ = fs::remove_file(&self.path);
        }
    }

    fn euclid(a: &[f32], b: &[f32]) -> f32 {
        a.iter()
            .zip(b)
            .map(|(x, y)| (x - y) * (x - y))
            .sum::<f32>()
            .sqrt()
    }

    #[test]
    fn test_small_index_l2() {
        let file = TempIndexFile::new("test_small_l2.db");

        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
            vec![0.5, 0.5],
        ];

        let index =
            DiskANN::<f32, DistL2>::build_index_default(&vectors, DistL2, file.path()).unwrap();

        let q = vec![0.1, 0.1];
        let nns = index.search(&q, 3, 8);
        assert_eq!(nns.len(), 3);

        let v = index.get_vector(nns[0] as usize);
        assert!(euclid(&q, &v) < 1.0);
    }

    #[test]
    fn test_cosine() {
        let file = TempIndexFile::new("test_cosine.db");

        let vectors = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
            vec![1.0, 1.0, 0.0],
            vec![1.0, 0.0, 1.0],
        ];

        let index =
            DiskANN::<f32, DistCosine>::build_index_default(&vectors, DistCosine, file.path())
                .unwrap();

        let q = vec![2.0, 0.0, 0.0];
        let nns = index.search(&q, 2, 8);
        assert_eq!(nns.len(), 2);

        let v = index.get_vector(nns[0] as usize);
        let dot = v.iter().zip(&q).map(|(a, b)| a * b).sum::<f32>();
        let n1 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        let n2 = q.iter().map(|x| x * x).sum::<f32>().sqrt();
        let cos = dot / (n1 * n2);
        assert!(cos > 0.7);
    }

    #[test]
    fn test_persistence_and_open() {
        let file = TempIndexFile::new("test_persist.db");

        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
        ];

        {
            let _idx = DiskANN::<f32, DistL2>::build_index_default(&vectors, DistL2, file.path())
                .unwrap();
        }

        let idx2 = DiskANN::<f32, DistL2>::open_index_default_metric(file.path()).unwrap();
        assert_eq!(idx2.num_vectors, 4);
        assert_eq!(idx2.dim, 2);

        let q = vec![0.9, 0.9];
        let res = idx2.search(&q, 2, 8);
        assert_eq!(res[0], 3);
    }

    #[test]
    fn test_grid_connectivity() {
        let file = TempIndexFile::new("test_grid.db");

        let mut vectors = Vec::new();
        for i in 0..5 {
            for j in 0..5 {
                vectors.push(vec![i as f32, j as f32]);
            }
        }

        let index = DiskANN::<f32, DistL2>::build_index_with_params(
            &vectors,
            DistL2,
            file.path(),
            DiskAnnParams {
                max_degree: 4,
                build_beam_width: 64,
                alpha: 1.5,
                passes: DISKANN_DEFAULT_PASSES,
                extra_seeds: DISKANN_DEFAULT_EXTRA_SEEDS,
            },
        )
        .unwrap();

        for target in 0..vectors.len() {
            let q = &vectors[target];
            let nns = index.search(q, 10, 32);
            assert!(!nns.is_empty(), "search returned no results for target {target}");
            if !nns.contains(&(target as u32)) {
                let v = index.get_vector(nns[0] as usize);
                assert!(euclid(q, &v) < 2.0);
            }
            for &nb in nns.iter().take(5) {
                let v = index.get_vector(nb as usize);
                assert!(euclid(q, &v) < 5.0);
            }
        }
    }

    #[test]
    fn test_medium_random() {
        let file = TempIndexFile::new("test_medium.db");

        let n = 200usize;
        let d = 32usize;
        let mut rng = rand::thread_rng();
        let vectors: Vec<Vec<f32>> = (0..n)
            .map(|_| (0..d).map(|_| rng.r#gen::<f32>()).collect())
            .collect();

        let index = DiskANN::<f32, DistL2>::build_index_with_params(
            &vectors,
            DistL2,
            file.path(),
            DiskAnnParams {
                max_degree: 32,
                build_beam_width: 128,
                alpha: 1.2,
                passes: DISKANN_DEFAULT_PASSES,
                extra_seeds: DISKANN_DEFAULT_EXTRA_SEEDS,
            },
        )
        .unwrap();

        let q: Vec<f32> = (0..d).map(|_| rng.r#gen::<f32>()).collect();
        let res = index.search(&q, 10, 64);
        assert_eq!(res.len(), 10);

        let dists: Vec<f32> = res
            .iter()
            .map(|&id| {
                let v = index.get_vector(id as usize);
                euclid(&q, &v)
            })
            .collect();
        let mut sorted = dists.clone();
        sorted.sort_by(|a, b| a.total_cmp(b));
        assert_eq!(dists, sorted);
    }
}