Skip to main content

rust_diskann/
lib.rs

1//! # DiskAnn (generic over `anndists::Distance<T>`)
2//!
3//! An on-disk DiskANN library that:
4//! - Builds a Vamana-style graph (greedy + α-pruning) in memory
5//! - Writes vectors + fixed-degree adjacency to a single file
6//! - Memory-maps the file for low-overhead reads
7//! - Is **generic over any Distance<T>** from `anndists` (e.g. L2 on `f32`, Cosine on `f32`,
8//!   Hamming on `u64`, …)
9//!
10//! ## Example (f32 + L2)
11//! ```no_run
12//! use anndists::dist::DistL2;
13//! use rust_diskann::{DiskANN, DiskAnnParams};
14//!
15//! let vectors: Vec<Vec<f32>> = vec![vec![0.0; 128]; 1000];
16//! let index = DiskANN::<f32, DistL2>::build_index_default(&vectors, DistL2, "index.db").unwrap();
17//!
18//! let q = vec![0.0; 128];
19//! let nns = index.search(&q, 10, 64);
20//! ```
21//!
22//! ## Example (u64 + Hamming)
23//! ```no_run
24//! use anndists::dist::DistHamming;
25//! use rust_diskann::{DiskANN, DiskAnnParams};
26//! let index: Vec<Vec<u64>> = vec![vec![0u64; 128]; 1000];
27//! let idx = DiskANN::<u64, DistHamming>::build_index_default(&index, DistHamming, "mh.db").unwrap();
28//! let q = vec![0u64; 128];
29//! let _ = idx.search(&q, 10, 64);
30//! ```
31//!
32//! ## File Layout
33//! [ metadata_len:u64 ][ metadata (bincode) ][ padding up to vectors_offset ]
34//! [ vectors (num * dim * T) ][ adjacency (num * max_degree * u32) ]
35//!
//! `vectors_offset` is currently always 1 MiB (not configurable); the length prefix and the
//! bincode metadata must fit in front of it.
37
38use anndists::prelude::Distance;
39use memmap2::Mmap;
40use rand::{prelude::*, thread_rng};
41use rayon::prelude::*;
42use serde::{Deserialize, Serialize};
43use std::cmp::{Ordering, Reverse};
44use std::collections::{BinaryHeap, HashSet};
45use std::fs::OpenOptions;
46use std::io::{Read, Seek, SeekFrom, Write};
47use std::marker::PhantomData;
48use thiserror::Error;
49
/// Padding sentinel for adjacency slots (avoid colliding with node 0).
const PAD_U32: u32 = u32::MAX;

/// Default maximum out-degree (`M`) for in-memory DiskANN builds.
pub const DISKANN_DEFAULT_MAX_DEGREE: usize = 64;
/// Default construction beam width (`L`) for in-memory DiskANN builds.
pub const DISKANN_DEFAULT_BUILD_BEAM: usize = 128;
/// Default robust-prune α for in-memory DiskANN builds.
pub const DISKANN_DEFAULT_ALPHA: f32 = 1.2;
/// Default number of refinement passes during graph build
pub const DISKANN_DEFAULT_PASSES: usize = 2;
/// Default number of extra random seeds per node per pass during graph build
pub const DISKANN_DEFAULT_EXTRA_SEEDS: usize = 2;

/// Optional bag of build knobs if you want to override just a few.
///
/// Start from `Default` and use struct-update syntax, e.g.
/// `DiskAnnParams { max_degree: 32, ..Default::default() }`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct DiskAnnParams {
    /// Maximum out-degree per node (`M`); every on-disk adjacency list is this wide.
    pub max_degree: usize,
    /// Beam width (`L`) of the greedy searches used to gather candidates during construction.
    pub build_beam_width: usize,
    /// Robust-prune α: a candidate is dropped when `α · d(selected, cand) <= d(node, cand)`,
    /// so larger values keep more long-range edges.
    pub alpha: f32,
    /// Number of refinement passes over the graph (>=1).
    pub passes: usize,
    /// Extra random seeds per node during each pass (>=0).
    pub extra_seeds: usize,
}

impl Default for DiskAnnParams {
    /// The `DISKANN_DEFAULT_*` constants bundled together.
    fn default() -> Self {
        Self {
            max_degree: DISKANN_DEFAULT_MAX_DEGREE,
            build_beam_width: DISKANN_DEFAULT_BUILD_BEAM,
            alpha: DISKANN_DEFAULT_ALPHA,
            passes: DISKANN_DEFAULT_PASSES,
            extra_seeds: DISKANN_DEFAULT_EXTRA_SEEDS,
        }
    }
}
85
86/// Custom error type for DiskAnn operations
87#[derive(Debug, Error)]
88pub enum DiskAnnError {
89    /// Represents I/O errors during file operations
90    #[error("I/O error: {0}")]
91    Io(#[from] std::io::Error),
92
93    /// Represents serialization/deserialization errors
94    #[error("Serialization error: {0}")]
95    Bincode(#[from] bincode::Error),
96
97    /// Represents index-specific errors
98    #[error("Index error: {0}")]
99    IndexError(String),
100}
101
102/// Internal metadata structure stored in the index file
103#[derive(Serialize, Deserialize, Debug)]
104struct Metadata {
105    dim: usize,
106    num_vectors: usize,
107    max_degree: usize,
108    medoid_id: u32,
109    vectors_offset: u64,
110    adjacency_offset: u64,
111    elem_size: u8,
112    distance_name: String,
113}
114
/// Candidate for search/frontier queues: a node id paired with its distance to the query.
///
/// Ordering is by distance (using [`f32::total_cmp`], so NaN never makes the order partial or
/// a sort panic), then by id. The id tie-break keeps `Ord` consistent with `PartialEq`/`Eq`
/// (the previous impl compared ids in `eq` but not in `cmp`) and makes heap order deterministic.
#[derive(Clone, Copy, Debug)]
struct Candidate {
    dist: f32,
    id: u32,
}

impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Candidate {}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> Ordering {
        self.dist
            .total_cmp(&other.dist)
            .then_with(|| self.id.cmp(&other.id))
    }
}
137
138/// Main struct representing a DiskANN index (generic over vector element `T` and distance `D`)
139pub struct DiskANN<T, D>
140where
141    T: bytemuck::Pod + Copy + Send + Sync + 'static,
142    D: Distance<T> + Send + Sync + Copy + Clone + 'static,
143{
144    /// Dimensionality of vectors in the index
145    pub dim: usize,
146    /// Number of vectors in the index
147    pub num_vectors: usize,
148    /// Maximum number of edges per node
149    pub max_degree: usize,
150    /// Informational: type name of the distance (from metadata)
151    pub distance_name: String,
152
153    /// ID of the medoid (used as entry point)
154    medoid_id: u32,
155    // Offsets
156    vectors_offset: u64,
157    adjacency_offset: u64,
158
159    /// Memory-mapped file
160    mmap: Mmap,
161
162    /// The distance strategy
163    dist: D,
164
165    /// keep `T` in the type so the compiler knows about it
166    _phantom: PhantomData<T>,
167}
168
169// constructors
170
171impl<T, D> DiskANN<T, D>
172where
173    T: bytemuck::Pod + Copy + Send + Sync + 'static,
174    D: Distance<T> + Send + Sync + Copy + Clone + 'static,
175{
176    /// Build with default parameters: (M=64, L=128, alpha=1.2, passes=2, extra_seeds=2).
177    pub fn build_index_default(
178        vectors: &[Vec<T>],
179        dist: D,
180        file_path: &str,
181    ) -> Result<Self, DiskAnnError> {
182        Self::build_index(
183            vectors,
184            DISKANN_DEFAULT_MAX_DEGREE,
185            DISKANN_DEFAULT_BUILD_BEAM,
186            DISKANN_DEFAULT_ALPHA,
187            DISKANN_DEFAULT_PASSES,
188            DISKANN_DEFAULT_EXTRA_SEEDS,
189            dist,
190            file_path,
191        )
192    }
193
194    /// Build with a `DiskAnnParams` bundle.
195    pub fn build_index_with_params(
196        vectors: &[Vec<T>],
197        dist: D,
198        file_path: &str,
199        p: DiskAnnParams,
200    ) -> Result<Self, DiskAnnError> {
201        Self::build_index(
202            vectors,
203            p.max_degree,
204            p.build_beam_width,
205            p.alpha,
206            p.passes,
207            p.extra_seeds,
208            dist,
209            file_path,
210        )
211    }
212
213    /// Opens an existing index file, supplying the distance strategy explicitly.
214    pub fn open_index_with(path: &str, dist: D) -> Result<Self, DiskAnnError> {
215        let mut file = OpenOptions::new().read(true).write(false).open(path)?;
216
217        // Read metadata length
218        let mut buf8 = [0u8; 8];
219        file.seek(SeekFrom::Start(0))?;
220        file.read_exact(&mut buf8)?;
221        let md_len = u64::from_le_bytes(buf8);
222
223        // Read metadata
224        let mut md_bytes = vec![0u8; md_len as usize];
225        file.read_exact(&mut md_bytes)?;
226        let metadata: Metadata = bincode::deserialize(&md_bytes)?;
227
228        let mmap = unsafe { memmap2::Mmap::map(&file)? };
229
230        // Validate element size vs T
231        let want = std::mem::size_of::<T>() as u8;
232        if metadata.elem_size != want {
233            return Err(DiskAnnError::IndexError(format!(
234                "element size mismatch: file has {}B, T is {}B",
235                metadata.elem_size, want
236            )));
237        }
238
239        // Optional sanity/logging: warn if type differs from recorded name
240        let expected = std::any::type_name::<D>();
241        if metadata.distance_name != expected {
242            eprintln!(
243                "Warning: index recorded distance `{}` but you opened with `{}`",
244                metadata.distance_name, expected
245            );
246        }
247
248        Ok(Self {
249            dim: metadata.dim,
250            num_vectors: metadata.num_vectors,
251            max_degree: metadata.max_degree,
252            distance_name: metadata.distance_name,
253            medoid_id: metadata.medoid_id,
254            vectors_offset: metadata.vectors_offset,
255            adjacency_offset: metadata.adjacency_offset,
256            mmap,
257            dist,
258            _phantom: PhantomData,
259        })
260    }
261}
262
263/// Extra sugar when your distance type implements `Default`.
264impl<T, D> DiskANN<T, D>
265where
266    T: bytemuck::Pod + Copy + Send + Sync + 'static,
267    D: Distance<T> + Default + Send + Sync + Copy + Clone + 'static,
268{
269    /// Build with default params **and** `D::default()` metric.
270    pub fn build_index_default_metric(
271        vectors: &[Vec<T>],
272        file_path: &str,
273    ) -> Result<Self, DiskAnnError> {
274        Self::build_index_default(vectors, D::default(), file_path)
275    }
276
277    /// Open an index using `D::default()` as the distance (matches what you built with).
278    pub fn open_index_default_metric(path: &str) -> Result<Self, DiskAnnError> {
279        Self::open_index_with(path, D::default())
280    }
281}
282
283impl<T, D> DiskANN<T, D>
284where
285    T: bytemuck::Pod + Copy + Send + Sync + 'static,
286    D: Distance<T> + Send + Sync + Copy + Clone + 'static,
287{
288    /// Builds a new index from provided vectors
289    ///
290    /// # Arguments
291    /// * `vectors` - The vectors to index (slice of Vec<T>)
292    /// * `max_degree` - Maximum edges per node (M ~ 24-64+)
293    /// * `build_beam_width` - Construction L (e.g., 128-400)
294    /// * `alpha` - Pruning parameter (1.2–2.0)
295    /// * `passes` - Refinement passes over the graph (>=1)
296    /// * `extra_seeds` - Extra random seeds per node per pass (>=0)
297    /// * `dist` - Any `anndists::Distance<T>`
298    /// * `file_path` - Path of index file
299    pub fn build_index(
300        vectors: &[Vec<T>],
301        max_degree: usize,
302        build_beam_width: usize,
303        alpha: f32,
304        passes: usize,
305        extra_seeds: usize,
306        dist: D,
307        file_path: &str,
308    ) -> Result<Self, DiskAnnError> {
309        if vectors.is_empty() {
310            return Err(DiskAnnError::IndexError("No vectors provided".to_string()));
311        }
312
313        let num_vectors = vectors.len();
314        let dim = vectors[0].len();
315        for (i, v) in vectors.iter().enumerate() {
316            if v.len() != dim {
317                return Err(DiskAnnError::IndexError(format!(
318                    "Vector {} has dimension {} but expected {}",
319                    i,
320                    v.len(),
321                    dim
322                )));
323            }
324        }
325
326        let mut file = OpenOptions::new()
327            .create(true)
328            .write(true)
329            .read(true)
330            .truncate(true)
331            .open(file_path)?;
332
333        // Reserve space for metadata (we'll write it after data)
334        let vectors_offset = 1024 * 1024;
335        // Ensure alignment (1 MiB is aligned for any T, but assert anyway)
336        assert_eq!(
337            (vectors_offset as usize) % std::mem::align_of::<T>(),
338            0,
339            "vectors_offset must be aligned for T"
340        );
341
342        let elem_sz = std::mem::size_of::<T>() as u64;
343        let total_vector_bytes = (num_vectors as u64) * (dim as u64) * elem_sz;
344
345        // Write vectors contiguous (sequential I/O is fastest)
346        file.seek(SeekFrom::Start(vectors_offset as u64))?;
347        for vector in vectors {
348            let bytes = bytemuck::cast_slice::<T, u8>(vector);
349            file.write_all(bytes)?;
350        }
351
352        // Compute medoid using provided distance (parallelized distance eval)
353        let medoid_id = calculate_medoid(vectors, dist);
354
355        // Build Vamana-like graph (stronger refinement, parallel inner loops)
356        let adjacency_offset = vectors_offset as u64 + total_vector_bytes;
357        let graph = build_vamana_graph(
358            vectors,
359            max_degree,
360            build_beam_width,
361            alpha,
362            passes,
363            extra_seeds,
364            dist,
365            medoid_id as u32,
366        );
367
368        // Write adjacency lists (fixed max_degree, pad with PAD_U32)
369        file.seek(SeekFrom::Start(adjacency_offset))?;
370        for neighbors in &graph {
371            let mut padded = neighbors.clone();
372            padded.resize(max_degree, PAD_U32);
373            let bytes = bytemuck::cast_slice::<u32, u8>(&padded);
374            file.write_all(bytes)?;
375        }
376
377        // Write metadata
378        let metadata = Metadata {
379            dim,
380            num_vectors,
381            max_degree,
382            medoid_id: medoid_id as u32,
383            vectors_offset: vectors_offset as u64,
384            adjacency_offset,
385            elem_size: std::mem::size_of::<T>() as u8,
386            distance_name: std::any::type_name::<D>().to_string(),
387        };
388
389        let md_bytes = bincode::serialize(&metadata)?;
390        file.seek(SeekFrom::Start(0))?;
391        let md_len = md_bytes.len() as u64;
392        file.write_all(&md_len.to_le_bytes())?;
393        file.write_all(&md_bytes)?;
394        file.sync_all()?;
395
396        // Memory map the file
397        let mmap = unsafe { memmap2::Mmap::map(&file)? };
398
399        Ok(Self {
400            dim,
401            num_vectors,
402            max_degree,
403            distance_name: metadata.distance_name,
404            medoid_id: metadata.medoid_id,
405            vectors_offset: metadata.vectors_offset,
406            adjacency_offset: metadata.adjacency_offset,
407            mmap,
408            dist,
409            _phantom: PhantomData,
410        })
411    }
412
413    /// Searches the index for nearest neighbors using a best-first beam search.
414    /// Termination rule: continue while the best frontier can still improve the worst in working set.
415    /// Like `search` but also returns the distance for each neighbor.
416    pub fn search_with_dists(&self, query: &[T], k: usize, beam_width: usize) -> Vec<(u32, f32)> {
417        assert_eq!(
418            query.len(),
419            self.dim,
420            "Query dim {} != index dim {}",
421            query.len(),
422            self.dim
423        );
424
425        let mut visited = HashSet::new();
426        let mut frontier: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new(); // best-first by dist
427        let mut w: BinaryHeap<Candidate> = BinaryHeap::new(); // working set, max-heap by dist
428
429        // seed from medoid
430        let start_dist = self.distance_to(query, self.medoid_id as usize);
431        let start = Candidate {
432            dist: start_dist,
433            id: self.medoid_id,
434        };
435        frontier.push(Reverse(start));
436        w.push(start);
437        visited.insert(self.medoid_id);
438
439        // expand while best frontier can still improve worst in working set
440        while let Some(Reverse(best)) = frontier.peek().copied() {
441            if w.len() >= beam_width {
442                if let Some(worst) = w.peek() {
443                    if best.dist >= worst.dist {
444                        break;
445                    }
446                }
447            }
448            let Reverse(current) = frontier.pop().unwrap();
449
450            for &nb in self.get_neighbors(current.id) {
451                if nb == PAD_U32 {
452                    continue;
453                }
454                if !visited.insert(nb) {
455                    continue;
456                }
457
458                let d = self.distance_to(query, nb as usize);
459                let cand = Candidate { dist: d, id: nb };
460
461                if w.len() < beam_width {
462                    w.push(cand);
463                    frontier.push(Reverse(cand));
464                } else if d < w.peek().unwrap().dist {
465                    w.pop();
466                    w.push(cand);
467                    frontier.push(Reverse(cand));
468                }
469            }
470        }
471
472        // top-k by distance, keep distances
473        let mut results: Vec<_> = w.into_vec();
474        results.sort_by(|a, b| a.dist.partial_cmp(&b.dist).unwrap());
475        results.truncate(k);
476        results.into_iter().map(|c| (c.id, c.dist)).collect()
477    }
478
479    /// search but only return neighbor ids
480    pub fn search(&self, query: &[T], k: usize, beam_width: usize) -> Vec<u32> {
481        self.search_with_dists(query, k, beam_width)
482            .into_iter()
483            .map(|(id, _dist)| id)
484            .collect()
485    }
486
487    /// Gets the neighbors of a node from the (fixed-degree) adjacency region
488    fn get_neighbors(&self, node_id: u32) -> &[u32] {
489        let offset = self.adjacency_offset + (node_id as u64 * self.max_degree as u64 * 4);
490        let start = offset as usize;
491        let end = start + (self.max_degree * 4);
492        let bytes = &self.mmap[start..end];
493        bytemuck::cast_slice(bytes)
494    }
495
496    /// Computes distance between `query` and vector `idx`
497    fn distance_to(&self, query: &[T], idx: usize) -> f32 {
498        let elem_sz = std::mem::size_of::<T>();
499        let offset = self.vectors_offset + (idx as u64 * self.dim as u64 * elem_sz as u64);
500        let start = offset as usize;
501        let end = start + (self.dim * elem_sz);
502        let bytes = &self.mmap[start..end];
503        let vector: &[T] = bytemuck::cast_slice(bytes);
504        self.dist.eval(query, vector)
505    }
506
507    /// Gets a vector from the index
508    pub fn get_vector(&self, idx: usize) -> Vec<T> {
509        let elem_sz = std::mem::size_of::<T>();
510        let offset = self.vectors_offset + (idx as u64 * self.dim as u64 * elem_sz as u64);
511        let start = offset as usize;
512        let end = start + (self.dim * elem_sz);
513        let bytes = &self.mmap[start..end];
514        let vector: &[T] = bytemuck::cast_slice(bytes);
515        vector.to_vec()
516    }
517}
518
519/// Calculates the medoid (vector closest to a small pivot set) using distance `D`
520/// Parallelizes the per-vector distance evaluations.
521fn calculate_medoid<T, D>(vectors: &[Vec<T>], dist: D) -> usize
522where
523    T: bytemuck::Pod + Copy + Send + Sync,
524    D: Distance<T> + Copy + Sync,
525{
526    let n = vectors.len();
527    let k = 8.min(n); // lightweight approximation
528    let mut rng = thread_rng();
529    let pivots: Vec<usize> = (0..k).map(|_| rng.gen_range(0..n)).collect();
530
531    let (best_idx, _best_score) = (0..n)
532        .into_par_iter()
533        .map(|i| {
534            let score: f32 = pivots
535                .iter()
536                .map(|&p| dist.eval(&vectors[i], &vectors[p]))
537                .sum();
538            (i, score)
539        })
540        .reduce(|| (0usize, f32::MAX), |a, b| if a.1 <= b.1 { a } else { b });
541
542    best_idx
543}
544
/// Builds a strengthened Vamana-like graph using multi-pass refinement.
/// - Multi-seed candidate gathering (medoid and random seeds)
/// - Union with current adjacency before α-prune
/// - `passes` refinement passes with symmetrization after each pass
///
/// Returns one adjacency list per input vector (list index = node id).
///
/// Within a pass, every node's proposal reads the same immutable `snapshot` of the previous
/// graph, so proposals are independent of rayon scheduling. Results still vary between runs,
/// because the bootstrap graph and the extra seeds are random.
///
/// NOTE(review): list lengths rely on `prune_neighbors` honoring `max_degree`. `build_index`
/// additionally truncates/pads every list to exactly `max_degree` when writing it to disk.
fn build_vamana_graph<T, D>(
    vectors: &[Vec<T>],
    max_degree: usize,
    build_beam_width: usize,
    alpha: f32,
    passes: usize,
    extra_seeds: usize,
    dist: D,
    medoid_id: u32,
) -> Vec<Vec<u32>>
where
    T: bytemuck::Pod + Copy + Send + Sync,
    D: Distance<T> + Copy + Sync,
{
    let n = vectors.len();
    let mut graph = vec![Vec::<u32>::new(); n];

    // Light random bootstrap to avoid disconnected starts
    {
        let mut rng = thread_rng();
        for i in 0..n {
            let mut s = HashSet::new();
            // ~max_degree/2 distinct random neighbors (at least 2), capped at n-1 so the
            // rejection loop below always terminates.
            let target = (max_degree / 2).max(2).min(n.saturating_sub(1));
            while s.len() < target {
                let nb = rng.gen_range(0..n);
                if nb != i {
                    s.insert(nb as u32);
                }
            }
            graph[i] = s.into_iter().collect();
        }
    }

    let passes = passes.max(1); // at least one pass is sensible

    let mut rng = thread_rng();
    for _pass in 0..passes {
        // Shuffle visit order each pass; `new_graph[pos]` holds the proposal for `order[pos]`.
        let mut order: Vec<usize> = (0..n).collect();
        order.shuffle(&mut rng);

        // Snapshot read of graph for parallel candidate building
        let snapshot = &graph;

        // Build new neighbor proposals in parallel
        let new_graph: Vec<Vec<u32>> = order
            .par_iter()
            .map(|&u| {
                let mut candidates: Vec<(u32, f32)> =
                    Vec::with_capacity(build_beam_width * (2 + extra_seeds));

                // Include current adjacency with distances
                for &nb in &snapshot[u] {
                    let d = dist.eval(&vectors[u], &vectors[nb as usize]);
                    candidates.push((nb, d));
                }

                // Seeds: always medoid + some random starts
                let mut seeds = Vec::with_capacity(1 + extra_seeds);
                seeds.push(medoid_id as usize);
                let mut trng = thread_rng();
                for _ in 0..extra_seeds {
                    seeds.push(trng.gen_range(0..n));
                }

                // Gather candidates from greedy searches
                for start in seeds {
                    let mut part = greedy_search(
                        &vectors[u],
                        vectors,
                        snapshot,
                        start,
                        build_beam_width,
                        dist,
                    );
                    candidates.append(&mut part);
                }

                // Deduplicate by id keeping best distance.
                // `dedup_by` passes (later, earlier) and drops `later` when the closure returns
                // true, so the smaller distance is copied into the element that is kept.
                candidates.sort_by(|a, b| a.0.cmp(&b.0));
                candidates.dedup_by(|a, b| {
                    if a.0 == b.0 {
                        if a.1 < b.1 {
                            *b = *a;
                        }
                        true
                    } else {
                        false
                    }
                });

                // α-prune around u (prune_neighbors skips u itself if a search returned it)
                prune_neighbors(u, &candidates, vectors, max_degree, alpha, dist)
            })
            .collect();

        // Symmetrize: union incoming + outgoing, then α-prune again (parallel)
        // pos_of[u] = position of node u in `order`, i.e. where its proposal sits in new_graph.
        let mut pos_of = vec![0usize; n];
        for (pos, &u) in order.iter().enumerate() {
            pos_of[u] = pos;
        }

        // Build incoming as CSR: incoming_flat[incoming_off[u]..incoming_off[u + 1]] are the
        // nodes whose new proposal contains an edge to u.
        let (incoming_flat, incoming_off) = build_incoming_csr(&order, &new_graph, n);

        // Union and prune in parallel
        graph = (0..n)
            .into_par_iter()
            .map(|u| {
                let ng = &new_graph[pos_of[u]]; // outgoing from this pass
                let inc = &incoming_flat[incoming_off[u]..incoming_off[u + 1]]; // incoming to u

                // pool = union(outgoing ∪ incoming) with tiny, cache-friendly ops
                let mut pool_ids: Vec<u32> = Vec::with_capacity(ng.len() + inc.len());
                pool_ids.extend_from_slice(ng);
                pool_ids.extend_from_slice(inc);
                pool_ids.sort_unstable();
                pool_ids.dedup();

                // compute distances once, then α-prune
                let pool: Vec<(u32, f32)> = pool_ids
                    .into_iter()
                    .filter(|&id| id as usize != u)
                    .map(|id| (id, dist.eval(&vectors[u], &vectors[id as usize])))
                    .collect();

                prune_neighbors(u, &pool, vectors, max_degree, alpha, dist)
            })
            .collect();
    }

    // Final cleanup: re-prune any list still longer than max_degree
    graph
        .into_par_iter()
        .enumerate()
        .map(|(u, neigh)| {
            if neigh.len() <= max_degree {
                return neigh;
            }
            let pool: Vec<(u32, f32)> = neigh
                .iter()
                .map(|&id| (id, dist.eval(&vectors[u], &vectors[id as usize])))
                .collect();
            prune_neighbors(u, &pool, vectors, max_degree, alpha, dist)
        })
        .collect()
}
696
697/// Greedy search used during construction (read-only on `graph`)
698/// Same termination rule as query-time search.
699fn greedy_search<T, D>(
700    query: &[T],
701    vectors: &[Vec<T>],
702    graph: &[Vec<u32>],
703    start_id: usize,
704    beam_width: usize,
705    dist: D,
706) -> Vec<(u32, f32)>
707where
708    T: bytemuck::Pod + Copy + Send + Sync,
709    D: Distance<T> + Copy,
710{
711    let mut visited = HashSet::new();
712    let mut frontier: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new(); // min-heap by dist
713    let mut w: BinaryHeap<Candidate> = BinaryHeap::new(); // max-heap by dist
714
715    let start_dist = dist.eval(query, &vectors[start_id]);
716    let start = Candidate {
717        dist: start_dist,
718        id: start_id as u32,
719    };
720    frontier.push(Reverse(start));
721    w.push(start);
722    visited.insert(start_id as u32);
723
724    while let Some(Reverse(best)) = frontier.peek().copied() {
725        if w.len() >= beam_width {
726            if let Some(worst) = w.peek() {
727                if best.dist >= worst.dist {
728                    break;
729                }
730            }
731        }
732        let Reverse(cur) = frontier.pop().unwrap();
733
734        for &nb in &graph[cur.id as usize] {
735            if !visited.insert(nb) {
736                continue;
737            }
738            let d = dist.eval(query, &vectors[nb as usize]);
739            let cand = Candidate { dist: d, id: nb };
740
741            if w.len() < beam_width {
742                w.push(cand);
743                frontier.push(Reverse(cand));
744            } else if d < w.peek().unwrap().dist {
745                w.pop();
746                w.push(cand);
747                frontier.push(Reverse(cand));
748            }
749        }
750    }
751
752    let mut v = w.into_vec();
753    v.sort_by(|a, b| a.dist.partial_cmp(&b.dist).unwrap());
754    v.into_iter().map(|c| (c.id, c.dist)).collect()
755}
756
757/// α-pruning
758fn prune_neighbors<T, D>(
759    node_id: usize,
760    candidates: &[(u32, f32)],
761    vectors: &[Vec<T>],
762    max_degree: usize,
763    alpha: f32,
764    dist: D,
765) -> Vec<u32>
766where
767    T: bytemuck::Pod + Copy + Send + Sync,
768    D: Distance<T> + Copy,
769{
770    if candidates.is_empty() {
771        return Vec::new();
772    }
773
774    let mut sorted = candidates.to_vec();
775    sorted.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
776
777    let mut pruned = Vec::<u32>::new();
778
779    for &(cand_id, cand_dist) in &sorted {
780        if cand_id as usize == node_id {
781            continue;
782        }
783        let mut ok = true;
784        for &sel in &pruned {
785            let d = dist.eval(&vectors[cand_id as usize], &vectors[sel as usize]);
786            // DiskANN / Vamana-style robust prune:
787            // prune cand_id if alpha * f(u,w) <= f(v,w)
788            if alpha * d <= cand_dist {
789                ok = false;
790                break;
791            }
792        }
793        if ok {
794            pruned.push(cand_id);
795            if pruned.len() >= max_degree {
796                break;
797            }
798        }
799    }
800
801    // fill with closest if still not full
802    for &(cand_id, _) in &sorted {
803        if cand_id as usize == node_id {
804            continue;
805        }
806        if !pruned.contains(&cand_id) {
807            pruned.push(cand_id);
808            if pruned.len() >= max_degree {
809                break;
810            }
811        }
812    }
813
814    pruned
815}
816
/// Inverts the per-pass proposals into a CSR (compressed sparse row) incoming-edge table.
///
/// `new_graph[pos]` is the outgoing list of node `order[pos]`. Returns `(sources, offsets)`:
/// `sources[offsets[v]..offsets[v + 1]]` lists the nodes with an edge to `v`, in `order`
/// sequence. `offsets` has length `n + 1`.
fn build_incoming_csr(order: &[usize], new_graph: &[Vec<u32>], n: usize) -> (Vec<u32>, Vec<usize>) {
    // Only the first `order.len()` proposals correspond to a node.
    let lists = &new_graph[..order.len()];

    // Count in-degrees shifted by one slot, then turn them into running offsets.
    let mut offsets = vec![0usize; n + 1];
    for &target in lists.iter().flatten() {
        offsets[target as usize + 1] += 1;
    }
    let mut running = 0usize;
    for slot in offsets.iter_mut() {
        running += *slot;
        *slot = running;
    }

    // Scatter each source into its target's bucket, advancing a per-target cursor.
    let mut cursor = offsets.clone();
    let mut sources = vec![0u32; offsets[n]];
    for (&src, list) in order.iter().zip(lists) {
        for &target in list {
            let next = &mut cursor[target as usize];
            sources[*next] = src as u32;
            *next += 1;
        }
    }

    (sources, offsets)
}
842
#[cfg(test)]
mod tests {
    use super::*;
    use anndists::dist::{DistCosine, DistL2};
    use rand::Rng;
    use std::fs;

    /// Index file in the system temp dir. It is removed on creation and again on drop, even if
    /// the test panics. The process id keeps concurrent test runs from sharing files, and tests
    /// no longer write into the current working directory.
    struct TempIndexFile(String);

    impl TempIndexFile {
        fn new(name: &str) -> Self {
            let path = std::env::temp_dir()
                .join(format!("rust_diskann_{}_{}", std::process::id(), name))
                .to_string_lossy()
                .into_owned();
            let _ = fs::remove_file(&path);
            Self(path)
        }

        fn path(&self) -> &str {
            &self.0
        }
    }

    impl Drop for TempIndexFile {
        fn drop(&mut self) {
            let _ = fs::remove_file(&self.0);
        }
    }

    fn euclid(a: &[f32], b: &[f32]) -> f32 {
        a.iter()
            .zip(b)
            .map(|(x, y)| (x - y) * (x - y))
            .sum::<f32>()
            .sqrt()
    }

    #[test]
    fn test_small_index_l2() {
        // Declared first so it is dropped (and the file removed) after the index is unmapped.
        let tmp = TempIndexFile::new("test_small_l2.db");

        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
            vec![0.5, 0.5],
        ];

        let index =
            DiskANN::<f32, DistL2>::build_index_default(&vectors, DistL2, tmp.path()).unwrap();

        let q = vec![0.1, 0.1];
        let nns = index.search(&q, 3, 8);
        assert_eq!(nns.len(), 3);

        // Verify the first neighbor is quite close
        let v = index.get_vector(nns[0] as usize);
        assert!(euclid(&q, &v) < 1.0);
    }

    #[test]
    fn test_cosine() {
        let tmp = TempIndexFile::new("test_cosine.db");

        let vectors = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
            vec![1.0, 1.0, 0.0],
            vec![1.0, 0.0, 1.0],
        ];

        let index =
            DiskANN::<f32, DistCosine>::build_index_default(&vectors, DistCosine, tmp.path())
                .unwrap();

        let q = vec![2.0, 0.0, 0.0]; // parallel to [1,0,0]
        let nns = index.search(&q, 2, 8);
        assert_eq!(nns.len(), 2);

        // Top neighbor should have high cosine similarity (close direction)
        let v = index.get_vector(nns[0] as usize);
        let dot = v.iter().zip(&q).map(|(a, b)| a * b).sum::<f32>();
        let n1 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        let n2 = q.iter().map(|x| x * x).sum::<f32>().sqrt();
        let cos = dot / (n1 * n2);
        assert!(cos > 0.7);
    }

    #[test]
    fn test_persistence_and_open() {
        let tmp = TempIndexFile::new("test_persist.db");

        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
        ];

        {
            let _idx =
                DiskANN::<f32, DistL2>::build_index_default(&vectors, DistL2, tmp.path()).unwrap();
        }

        // Use the default-metric opener (D: Default), keeping the same T
        let idx2 = DiskANN::<f32, DistL2>::open_index_default_metric(tmp.path()).unwrap();
        assert_eq!(idx2.num_vectors, 4);
        assert_eq!(idx2.dim, 2);

        let q = vec![0.9, 0.9];
        let res = idx2.search(&q, 2, 8);
        // [1,1] should be best (index 3)
        assert_eq!(res[0], 3);
    }

    #[test]
    fn test_grid_connectivity() {
        let tmp = TempIndexFile::new("test_grid.db");

        // 5x5 grid
        let mut vectors = Vec::new();
        for i in 0..5 {
            for j in 0..5 {
                vectors.push(vec![i as f32, j as f32]);
            }
        }

        let index = DiskANN::<f32, DistL2>::build_index_with_params(
            &vectors,
            DistL2,
            tmp.path(),
            DiskAnnParams {
                max_degree: 4,
                build_beam_width: 64,
                alpha: 1.5,
                passes: DISKANN_DEFAULT_PASSES,
                extra_seeds: DISKANN_DEFAULT_EXTRA_SEEDS,
            },
        )
        .unwrap();

        for target in 0..vectors.len() {
            let q = &vectors[target];
            let nns = index.search(q, 10, 32);
            if !nns.contains(&(target as u32)) {
                let v = index.get_vector(nns[0] as usize);
                assert!(euclid(q, &v) < 2.0);
            }
            for &nb in nns.iter().take(5) {
                let v = index.get_vector(nb as usize);
                assert!(euclid(q, &v) < 5.0);
            }
        }
    }

    #[test]
    fn test_medium_random() {
        let tmp = TempIndexFile::new("test_medium.db");

        let n = 200usize;
        let d = 32usize;
        let mut rng = rand::thread_rng();
        let vectors: Vec<Vec<f32>> = (0..n)
            .map(|_| (0..d).map(|_| rng.r#gen::<f32>()).collect())
            .collect();

        let index = DiskANN::<f32, DistL2>::build_index_with_params(
            &vectors,
            DistL2,
            tmp.path(),
            DiskAnnParams {
                max_degree: 32,
                build_beam_width: 128,
                alpha: 1.2,
                passes: DISKANN_DEFAULT_PASSES,
                extra_seeds: DISKANN_DEFAULT_EXTRA_SEEDS,
            },
        )
        .unwrap();

        let q: Vec<f32> = (0..d).map(|_| rng.r#gen::<f32>()).collect();
        let res = index.search(&q, 10, 64);
        assert_eq!(res.len(), 10);

        // Ensure distances are nondecreasing
        let dists: Vec<f32> = res
            .iter()
            .map(|&id| {
                let v = index.get_vector(id as usize);
                euclid(&q, &v)
            })
            .collect();
        let mut sorted = dists.clone();
        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
        assert_eq!(dists, sorted);
    }
}