rust_diskann/
lib.rs

1//! # DiskAnn (generic over `anndists::Distance<T>`)
2//!
3//! An on-disk DiskANN library that:
4//! - Builds a Vamana-style graph (greedy + α-pruning) in memory
5//! - Writes vectors + fixed-degree adjacency to a single file
6//! - Memory-maps the file for low-overhead reads
7//! - Is **generic over any Distance<T>** from `anndists` (e.g. L2 on `f32`, Cosine on `f32`,
8//!   Hamming on `u64`, …)
9//!
10//! ## Example (f32 + L2)
11//! ```no_run
12//! use anndists::dist::DistL2;
13//! use rust_diskann::{DiskANN, DiskAnnParams};
14//!
15//! let vectors: Vec<Vec<f32>> = vec![vec![0.0; 128]; 1000];
16//! let index = DiskANN::<f32, DistL2>::build_index_default(&vectors, DistL2, "index.db").unwrap();
17//!
18//! let q = vec![0.0; 128];
19//! let nns = index.search(&q, 10, 64);
20//! ```
21//!
22//! ## Example (u64 + Hamming)
23//! ```no_run
24//! use anndists::dist::DistHamming;
25//! use rust_diskann::{DiskANN, DiskAnnParams};
//! let vectors: Vec<Vec<u64>> = vec![vec![0u64; 128]; 1000];
//! let idx = DiskANN::<u64, DistHamming>::build_index_default(&vectors, DistHamming, "mh.db").unwrap();
28//! let q = vec![0u64; 128];
29//! let _ = idx.search(&q, 10, 64);
30//! ```
31//!
32//! ## File Layout
//! ```text
//! [ metadata_len:u64 ][ metadata (bincode) ][ padding up to vectors_offset ]
//! [ vectors (num * dim * T) ][ adjacency (num * max_degree * u32) ]
//! ```
//!
//! `vectors_offset` is fixed at 1 MiB, which reserves room for the metadata header.
37
38use anndists::prelude::Distance;
39use memmap2::Mmap;
40use rand::{prelude::*, thread_rng};
41use rayon::prelude::*;
42use serde::{Deserialize, Serialize};
43use std::cmp::{Ordering, Reverse};
44use std::collections::{BinaryHeap, HashSet};
45use std::fs::OpenOptions;
46use std::io::{Read, Seek, SeekFrom, Write};
47use std::marker::PhantomData;
48use thiserror::Error;
49
/// Padding sentinel for unused adjacency slots.
///
/// `u32::MAX` is used rather than `0` so that padding can never collide with node 0.
const PAD_U32: u32 = u32::MAX;

/// Default maximum number of edges per node for in-memory DiskANN builds.
pub const DISKANN_DEFAULT_MAX_DEGREE: usize = 64;
/// Default beam width used while constructing the graph.
pub const DISKANN_DEFAULT_BUILD_BEAM: usize = 128;
/// Default α used by the pruning step.
pub const DISKANN_DEFAULT_ALPHA: f32 = 1.2;

/// Optional bag of knobs if you want to override just a few build parameters.
///
/// `DiskAnnParams::default()` yields the `DISKANN_DEFAULT_*` constants, so struct-update
/// syntax (`DiskAnnParams { alpha: 1.5, ..Default::default() }`) overrides a single field.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct DiskAnnParams {
    /// Maximum number of edges stored per node.
    pub max_degree: usize,
    /// Beam width of the greedy searches run during construction.
    pub build_beam_width: usize,
    /// α factor applied during neighbor pruning.
    pub alpha: f32,
}

impl Default for DiskAnnParams {
    fn default() -> Self {
        Self {
            max_degree: DISKANN_DEFAULT_MAX_DEGREE,
            build_beam_width: DISKANN_DEFAULT_BUILD_BEAM,
            alpha: DISKANN_DEFAULT_ALPHA,
        }
    }
}
74
/// Errors returned by DiskANN index construction and loading.
///
/// The `#[from]` attributes let `?` convert `std::io::Error` and `bincode::Error`
/// into this type automatically.
#[derive(Debug, Error)]
pub enum DiskAnnError {
    /// An I/O operation on the index file failed (open, seek, read, write, sync or mmap).
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),

    /// The metadata header could not be serialized or deserialized with bincode.
    #[error("Serialization error: {0}")]
    Bincode(#[from] bincode::Error),

    /// The input or the index itself is invalid, e.g. no vectors were provided, the
    /// vectors have differing dimensions, or the file's element size does not match `T`.
    #[error("Index error: {0}")]
    IndexError(String),
}
90
/// Internal metadata structure stored in the index file.
///
/// It is bincode-encoded and written at the start of the file, right after an
/// 8-byte little-endian length prefix.
#[derive(Serialize, Deserialize, Debug)]
struct Metadata {
    /// Number of elements per vector.
    dim: usize,
    /// Number of vectors stored in the file.
    num_vectors: usize,
    /// Number of `u32` slots written per adjacency row.
    max_degree: usize,
    /// Id of the node used as the search entry point.
    medoid_id: u32,
    /// Byte offset of the first vector in the file.
    vectors_offset: u64,
    /// Byte offset of the first adjacency row in the file.
    adjacency_offset: u64,
    /// `size_of::<T>()` recorded at build time, narrowed to `u8`.
    elem_size: u8,
    /// `std::any::type_name::<D>()` of the distance used at build time (informational).
    distance_name: String,
}
103
/// Candidate for search/frontier queues: a node id paired with its distance to the query.
///
/// Candidates are ordered by `dist` first and by `id` second. A NaN distance sorts after
/// every real distance, so the order is total and `Eq`, `PartialOrd` and `Ord` agree with
/// each other, which is what `BinaryHeap` and the sort routines rely on.
#[derive(Clone, Copy, Debug)]
struct Candidate {
    dist: f32,
    id: u32,
}

impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so equality can never disagree with the ordering.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Candidate {}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> Ordering {
        self.dist
            .partial_cmp(&other.dist)
            // `partial_cmp` is `None` only when a NaN is involved: rank NaN last.
            .unwrap_or_else(|| self.dist.is_nan().cmp(&other.dist.is_nan()))
            // Tie-break on id so `cmp == Equal` exactly when both fields match.
            .then_with(|| self.id.cmp(&other.id))
    }
}
126
/// Main struct representing a DiskANN index (generic over vector element `T` and distance `D`).
///
/// The vectors and the fixed-degree adjacency lists are read from a memory-mapped file;
/// the struct itself only keeps the header values needed to locate them.
pub struct DiskANN<T, D>
where
    T: bytemuck::Pod + Copy + Send + Sync + 'static,
    D: Distance<T> + Send + Sync + Copy + Clone + 'static,
{
    /// Dimensionality of vectors in the index
    pub dim: usize,
    /// Number of vectors in the index
    pub num_vectors: usize,
    /// Maximum number of edges per node (width of each adjacency row)
    pub max_degree: usize,
    /// Informational: type name of the distance (from metadata)
    pub distance_name: String,

    /// ID of the medoid (used as entry point for searches)
    medoid_id: u32,
    // Byte offsets of the vector and adjacency regions inside the mapped file
    vectors_offset: u64,
    adjacency_offset: u64,

    /// Memory-mapped index file
    mmap: Mmap,

    /// The distance strategy used to compare a query with stored vectors
    dist: D,

    /// keep `T` in the type so the compiler knows about it
    _phantom: PhantomData<T>,
}
157
158// constructors
159
160impl<T, D> DiskANN<T, D>
161where
162    T: bytemuck::Pod + Copy + Send + Sync + 'static,
163    D: Distance<T> + Send + Sync + Copy + Clone + 'static,
164{
165    /// Build with default parameters: (M=64, L=128, alpha=1.2).
166    pub fn build_index_default(
167        vectors: &[Vec<T>],
168        dist: D,
169        file_path: &str,
170    ) -> Result<Self, DiskAnnError> {
171        Self::build_index(
172            vectors,
173            DISKANN_DEFAULT_MAX_DEGREE,
174            DISKANN_DEFAULT_BUILD_BEAM,
175            DISKANN_DEFAULT_ALPHA,
176            dist,
177            file_path,
178        )
179    }
180
181    /// Build with a `DiskAnnParams` bundle.
182    pub fn build_index_with_params(
183        vectors: &[Vec<T>],
184        dist: D,
185        file_path: &str,
186        p: DiskAnnParams,
187    ) -> Result<Self, DiskAnnError> {
188        Self::build_index(
189            vectors,
190            p.max_degree,
191            p.build_beam_width,
192            p.alpha,
193            dist,
194            file_path,
195        )
196    }
197
198    /// Opens an existing index file, supplying the distance strategy explicitly.
199    pub fn open_index_with(path: &str, dist: D) -> Result<Self, DiskAnnError> {
200        let mut file = OpenOptions::new().read(true).write(false).open(path)?;
201
202        // Read metadata length
203        let mut buf8 = [0u8; 8];
204        file.seek(SeekFrom::Start(0))?;
205        file.read_exact(&mut buf8)?;
206        let md_len = u64::from_le_bytes(buf8);
207
208        // Read metadata
209        let mut md_bytes = vec![0u8; md_len as usize];
210        file.read_exact(&mut md_bytes)?;
211        let metadata: Metadata = bincode::deserialize(&md_bytes)?;
212
213        let mmap = unsafe { memmap2::Mmap::map(&file)? };
214
215        // Validate element size vs T
216        let want = std::mem::size_of::<T>() as u8;
217        if metadata.elem_size != want {
218            return Err(DiskAnnError::IndexError(format!(
219                "element size mismatch: file has {}B, T is {}B",
220                metadata.elem_size, want
221            )));
222        }
223
224        // Optional sanity/logging: warn if type differs from recorded name
225        let expected = std::any::type_name::<D>();
226        if metadata.distance_name != expected {
227            eprintln!(
228                "Warning: index recorded distance `{}` but you opened with `{}`",
229                metadata.distance_name, expected
230            );
231        }
232
233        Ok(Self {
234            dim: metadata.dim,
235            num_vectors: metadata.num_vectors,
236            max_degree: metadata.max_degree,
237            distance_name: metadata.distance_name,
238            medoid_id: metadata.medoid_id,
239            vectors_offset: metadata.vectors_offset,
240            adjacency_offset: metadata.adjacency_offset,
241            mmap,
242            dist,
243            _phantom: PhantomData,
244        })
245    }
246}
247
248/// Extra sugar when your distance type implements `Default`.
249impl<T, D> DiskANN<T, D>
250where
251    T: bytemuck::Pod + Copy + Send + Sync + 'static,
252    D: Distance<T> + Default + Send + Sync + Copy + Clone + 'static,
253{
254    /// Build with default params **and** `D::default()` metric.
255    pub fn build_index_default_metric(
256        vectors: &[Vec<T>],
257        file_path: &str,
258    ) -> Result<Self, DiskAnnError> {
259        Self::build_index_default(vectors, D::default(), file_path)
260    }
261
262    /// Open an index using `D::default()` as the distance (matches what you built with).
263    pub fn open_index_default_metric(path: &str) -> Result<Self, DiskAnnError> {
264        Self::open_index_with(path, D::default())
265    }
266}
267
268impl<T, D> DiskANN<T, D>
269where
270    T: bytemuck::Pod + Copy + Send + Sync + 'static,
271    D: Distance<T> + Send + Sync + Copy + Clone + 'static,
272{
273    /// Builds a new index from provided vectors
274    ///
275    /// # Arguments
276    /// * `vectors` - The vectors to index (slice of Vec<T>)
277    /// * `max_degree` - Maximum edges per node (M ~ 24-64+)
278    /// * `build_beam_width` - Construction L (e.g., 128-400)
279    /// * `alpha` - Pruning parameter (1.2–2.0)
280    /// * `dist` - Any `anndists::Distance<T>`
281    /// * `file_path` - Path of index file
282    pub fn build_index(
283        vectors: &[Vec<T>],
284        max_degree: usize,
285        build_beam_width: usize,
286        alpha: f32,
287        dist: D,
288        file_path: &str,
289    ) -> Result<Self, DiskAnnError> {
290        if vectors.is_empty() {
291            return Err(DiskAnnError::IndexError("No vectors provided".to_string()));
292        }
293
294        let num_vectors = vectors.len();
295        let dim = vectors[0].len();
296        for (i, v) in vectors.iter().enumerate() {
297            if v.len() != dim {
298                return Err(DiskAnnError::IndexError(format!(
299                    "Vector {} has dimension {} but expected {}",
300                    i,
301                    v.len(),
302                    dim
303                )));
304            }
305        }
306
307        let mut file = OpenOptions::new()
308            .create(true)
309            .write(true)
310            .read(true)
311            .truncate(true)
312            .open(file_path)?;
313
314        // Reserve space for metadata (we'll write it after data)
315        let vectors_offset = 1024 * 1024;
316        // Ensure alignment (1 MiB is aligned for any T, but assert anyway)
317        assert_eq!(
318            (vectors_offset as usize) % std::mem::align_of::<T>(),
319            0,
320            "vectors_offset must be aligned for T"
321        );
322
323        let elem_sz = std::mem::size_of::<T>() as u64;
324        let total_vector_bytes = (num_vectors as u64) * (dim as u64) * elem_sz;
325
326        // Write vectors contiguous (sequential I/O is fastest)
327        file.seek(SeekFrom::Start(vectors_offset as u64))?;
328        for vector in vectors {
329            let bytes = bytemuck::cast_slice::<T, u8>(vector);
330            file.write_all(bytes)?;
331        }
332
333        // Compute medoid using provided distance (parallelized distance eval)
334        let medoid_id = calculate_medoid(vectors, dist);
335
336        // Build Vamana-like graph (stronger refinement, parallel inner loops)
337        let adjacency_offset = vectors_offset as u64 + total_vector_bytes;
338        let graph = build_vamana_graph(
339            vectors,
340            max_degree,
341            build_beam_width,
342            alpha,
343            dist,
344            medoid_id as u32,
345        );
346
347        // Write adjacency lists (fixed max_degree, pad with PAD_U32)
348        file.seek(SeekFrom::Start(adjacency_offset))?;
349        for neighbors in &graph {
350            let mut padded = neighbors.clone();
351            padded.resize(max_degree, PAD_U32);
352            let bytes = bytemuck::cast_slice::<u32, u8>(&padded);
353            file.write_all(bytes)?;
354        }
355
356        // Write metadata
357        let metadata = Metadata {
358            dim,
359            num_vectors,
360            max_degree,
361            medoid_id: medoid_id as u32,
362            vectors_offset: vectors_offset as u64,
363            adjacency_offset,
364            elem_size: std::mem::size_of::<T>() as u8,
365            distance_name: std::any::type_name::<D>().to_string(),
366        };
367
368        let md_bytes = bincode::serialize(&metadata)?;
369        file.seek(SeekFrom::Start(0))?;
370        let md_len = md_bytes.len() as u64;
371        file.write_all(&md_len.to_le_bytes())?;
372        file.write_all(&md_bytes)?;
373        file.sync_all()?;
374
375        // Memory map the file
376        let mmap = unsafe { memmap2::Mmap::map(&file)? };
377
378        Ok(Self {
379            dim,
380            num_vectors,
381            max_degree,
382            distance_name: metadata.distance_name,
383            medoid_id: metadata.medoid_id,
384            vectors_offset: metadata.vectors_offset,
385            adjacency_offset: metadata.adjacency_offset,
386            mmap,
387            dist,
388            _phantom: PhantomData,
389        })
390    }
391
392    /// Searches the index for nearest neighbors using a best-first beam search.
393    /// Termination rule: continue while the best frontier can still improve the worst in working set.
394    /// Like `search` but also returns the distance for each neighbor.
395    pub fn search_with_dists(&self, query: &[T], k: usize, beam_width: usize) -> Vec<(u32, f32)> {
396        assert_eq!(
397            query.len(),
398            self.dim,
399            "Query dim {} != index dim {}",
400            query.len(),
401            self.dim
402        );
403
404        let mut visited = HashSet::new();
405        let mut frontier: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new(); // best-first by dist
406        let mut w: BinaryHeap<Candidate> = BinaryHeap::new(); // working set, max-heap by dist
407
408        // seed from medoid
409        let start_dist = self.distance_to(query, self.medoid_id as usize);
410        let start = Candidate {
411            dist: start_dist,
412            id: self.medoid_id,
413        };
414        frontier.push(Reverse(start));
415        w.push(start);
416        visited.insert(self.medoid_id);
417
418        // expand while best frontier can still improve worst in working set
419        while let Some(Reverse(best)) = frontier.peek().copied() {
420            if w.len() >= beam_width {
421                if let Some(worst) = w.peek() {
422                    if best.dist >= worst.dist {
423                        break;
424                    }
425                }
426            }
427            let Reverse(current) = frontier.pop().unwrap();
428
429            for &nb in self.get_neighbors(current.id) {
430                if nb == PAD_U32 {
431                    continue;
432                }
433                if !visited.insert(nb) {
434                    continue;
435                }
436
437                let d = self.distance_to(query, nb as usize);
438                let cand = Candidate { dist: d, id: nb };
439
440                if w.len() < beam_width {
441                    w.push(cand);
442                    frontier.push(Reverse(cand));
443                } else if d < w.peek().unwrap().dist {
444                    w.pop();
445                    w.push(cand);
446                    frontier.push(Reverse(cand));
447                }
448            }
449        }
450
451        // top-k by distance, keep distances
452        let mut results: Vec<_> = w.into_vec();
453        results.sort_by(|a, b| a.dist.partial_cmp(&b.dist).unwrap());
454        results.truncate(k);
455        results.into_iter().map(|c| (c.id, c.dist)).collect()
456    }
457
458    /// search but only return neighbor ids
459    pub fn search(&self, query: &[T], k: usize, beam_width: usize) -> Vec<u32> {
460        self.search_with_dists(query, k, beam_width)
461            .into_iter()
462            .map(|(id, _dist)| id)
463            .collect()
464    }
465
466    /// Gets the neighbors of a node from the (fixed-degree) adjacency region
467    fn get_neighbors(&self, node_id: u32) -> &[u32] {
468        let offset = self.adjacency_offset + (node_id as u64 * self.max_degree as u64 * 4);
469        let start = offset as usize;
470        let end = start + (self.max_degree * 4);
471        let bytes = &self.mmap[start..end];
472        bytemuck::cast_slice(bytes)
473    }
474
475    /// Computes distance between `query` and vector `idx`
476    fn distance_to(&self, query: &[T], idx: usize) -> f32 {
477        let elem_sz = std::mem::size_of::<T>();
478        let offset = self.vectors_offset + (idx as u64 * self.dim as u64 * elem_sz as u64);
479        let start = offset as usize;
480        let end = start + (self.dim * elem_sz);
481        let bytes = &self.mmap[start..end];
482        let vector: &[T] = bytemuck::cast_slice(bytes);
483        self.dist.eval(query, vector)
484    }
485
486    /// Gets a vector from the index
487    pub fn get_vector(&self, idx: usize) -> Vec<T> {
488        let elem_sz = std::mem::size_of::<T>();
489        let offset = self.vectors_offset + (idx as u64 * self.dim as u64 * elem_sz as u64);
490        let start = offset as usize;
491        let end = start + (self.dim * elem_sz);
492        let bytes = &self.mmap[start..end];
493        let vector: &[T] = bytemuck::cast_slice(bytes);
494        vector.to_vec()
495    }
496}
497
498/// Calculates the medoid (vector closest to a small pivot set) using distance `D`
499/// Parallelizes the per-vector distance evaluations.
500fn calculate_medoid<T, D>(vectors: &[Vec<T>], dist: D) -> usize
501where
502    T: bytemuck::Pod + Copy + Send + Sync,
503    D: Distance<T> + Copy + Sync,
504{
505    let n = vectors.len();
506    let k = 8.min(n); // lightweight approximation
507    let mut rng = thread_rng();
508    let pivots: Vec<usize> = (0..k).map(|_| rng.gen_range(0..n)).collect();
509
510    let (best_idx, _best_score) = (0..n)
511        .into_par_iter()
512        .map(|i| {
513            let score: f32 = pivots
514                .iter()
515                .map(|&p| dist.eval(&vectors[i], &vectors[p]))
516                .sum();
517            (i, score)
518        })
519        .reduce(|| (0usize, f32::MAX), |a, b| if a.1 <= b.1 { a } else { b });
520
521    best_idx
522}
523
524/// Builds a strengthened Vamana-like graph using multi-pass refinement.
525/// - Multi-seed candidate gathering (medoid + random seeds)
526/// - Union with current adjacency before α-prune
527/// - 2 refinement passes with symmetrization after each pass
528fn build_vamana_graph<T, D>(
529    vectors: &[Vec<T>],
530    max_degree: usize,
531    build_beam_width: usize,
532    alpha: f32,
533    dist: D,
534    medoid_id: u32,
535) -> Vec<Vec<u32>>
536where
537    T: bytemuck::Pod + Copy + Send + Sync,
538    D: Distance<T> + Copy + Sync,
539{
540    let n = vectors.len();
541    let mut graph = vec![Vec::<u32>::new(); n];
542
543    // Light random bootstrap to avoid disconnected starts
544    {
545        let mut rng = thread_rng();
546        for i in 0..n {
547            let mut s = HashSet::new();
548            let target = (max_degree / 2).max(2).min(n.saturating_sub(1));
549            while s.len() < target {
550                let nb = rng.gen_range(0..n);
551                if nb != i {
552                    s.insert(nb as u32);
553                }
554            }
555            graph[i] = s.into_iter().collect();
556        }
557    }
558
559    // Refinement passes
560    const PASSES: usize = 2;
561    const EXTRA_SEEDS: usize = 2;
562
563    let mut rng = thread_rng();
564    for _pass in 0..PASSES {
565        // Shuffle visit order each pass
566        let mut order: Vec<usize> = (0..n).collect();
567        order.shuffle(&mut rng);
568
569        // Snapshot read of graph for parallel candidate building
570        let snapshot = &graph;
571
572        // Build new neighbor proposals in parallel
573        let new_graph: Vec<Vec<u32>> = order
574            .par_iter()
575            .map(|&u| {
576                let mut candidates: Vec<(u32, f32)> =
577                    Vec::with_capacity(build_beam_width * (2 + EXTRA_SEEDS));
578
579                // Include current adjacency with distances
580                for &nb in &snapshot[u] {
581                    let d = dist.eval(&vectors[u], &vectors[nb as usize]);
582                    candidates.push((nb, d));
583                }
584
585                // Seeds: always medoid + some random starts
586                let mut seeds = Vec::with_capacity(1 + EXTRA_SEEDS);
587                seeds.push(medoid_id as usize);
588                let mut trng = thread_rng();
589                for _ in 0..EXTRA_SEEDS {
590                    seeds.push(trng.gen_range(0..n));
591                }
592
593                // Gather candidates from greedy searches
594                for start in seeds {
595                    let mut part =
596                        greedy_search(&vectors[u], vectors, snapshot, start, build_beam_width, dist);
597                    candidates.append(&mut part);
598                }
599
600                // Deduplicate by id keeping best distance
601                candidates.sort_by(|a, b| a.0.cmp(&b.0));
602                candidates.dedup_by(|a, b| {
603                    if a.0 == b.0 {
604                        if a.1 < b.1 {
605                            *b = *a;
606                        }
607                        true
608                    } else {
609                        false
610                    }
611                });
612
613                // α-prune around u
614                prune_neighbors(u, &candidates, vectors, max_degree, alpha, dist)
615            })
616            .collect();
617
618        // Symmetrize: union incoming + outgoing, then α-prune again (parallel)
619        let mut pos_of = vec![0usize; n];
620        for (pos, &u) in order.iter().enumerate() {
621            pos_of[u] = pos;
622        }
623
624        // Build incoming as CSR
625        let (incoming_flat, incoming_off) = build_incoming_csr(&order, &new_graph, n);
626
627        // Union and prune in parallel
628        graph = (0..n)
629            .into_par_iter()
630            .map(|u| {
631                let ng = &new_graph[pos_of[u]]; // outgoing from this pass
632                let inc = &incoming_flat[incoming_off[u]..incoming_off[u + 1]]; // incoming to u
633
634                // pool = union(outgoing ∪ incoming) with tiny, cache-friendly ops
635                let mut pool_ids: Vec<u32> = Vec::with_capacity(ng.len() + inc.len());
636                pool_ids.extend_from_slice(ng);
637                pool_ids.extend_from_slice(inc);
638                pool_ids.sort_unstable();
639                pool_ids.dedup();
640
641                // compute distances once, then α-prune
642                let pool: Vec<(u32, f32)> = pool_ids
643                    .into_iter()
644                    .filter(|&id| id as usize != u)
645                    .map(|id| (id, dist.eval(&vectors[u], &vectors[id as usize])))
646                    .collect();
647
648                prune_neighbors(u, &pool, vectors, max_degree, alpha, dist)
649            })
650            .collect();
651    }
652
653    // Final cleanup (ensure <= max_degree everywhere)
654    graph
655        .into_par_iter()
656        .enumerate()
657        .map(|(u, neigh)| {
658            if neigh.len() <= max_degree {
659                return neigh;
660            }
661            let pool: Vec<(u32, f32)> = neigh
662                .iter()
663                .map(|&id| (id, dist.eval(&vectors[u], &vectors[id as usize])))
664                .collect();
665            prune_neighbors(u, &pool, vectors, max_degree, alpha, dist)
666        })
667        .collect()
668}
669
670/// Greedy search used during construction (read-only on `graph`)
671/// Same termination rule as query-time search.
672fn greedy_search<T, D>(
673    query: &[T],
674    vectors: &[Vec<T>],
675    graph: &[Vec<u32>],
676    start_id: usize,
677    beam_width: usize,
678    dist: D,
679) -> Vec<(u32, f32)>
680where
681    T: bytemuck::Pod + Copy + Send + Sync,
682    D: Distance<T> + Copy,
683{
684    let mut visited = HashSet::new();
685    let mut frontier: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new(); // min-heap by dist
686    let mut w: BinaryHeap<Candidate> = BinaryHeap::new(); // max-heap by dist
687
688    let start_dist = dist.eval(query, &vectors[start_id]);
689    let start = Candidate {
690        dist: start_dist,
691        id: start_id as u32,
692    };
693    frontier.push(Reverse(start));
694    w.push(start);
695    visited.insert(start_id as u32);
696
697    while let Some(Reverse(best)) = frontier.peek().copied() {
698        if w.len() >= beam_width {
699            if let Some(worst) = w.peek() {
700                if best.dist >= worst.dist {
701                    break;
702                }
703            }
704        }
705        let Reverse(cur) = frontier.pop().unwrap();
706
707        for &nb in &graph[cur.id as usize] {
708            if !visited.insert(nb) {
709                continue;
710            }
711            let d = dist.eval(query, &vectors[nb as usize]);
712            let cand = Candidate { dist: d, id: nb };
713
714            if w.len() < beam_width {
715                w.push(cand);
716                frontier.push(Reverse(cand));
717            } else if d < w.peek().unwrap().dist {
718                w.pop();
719                w.push(cand);
720                frontier.push(Reverse(cand));
721            }
722        }
723    }
724
725    let mut v = w.into_vec();
726    v.sort_by(|a, b| a.dist.partial_cmp(&b.dist).unwrap());
727    v.into_iter().map(|c| (c.id, c.dist)).collect()
728}
729
730/// α-pruning
731fn prune_neighbors<T, D>(
732    node_id: usize,
733    candidates: &[(u32, f32)],
734    vectors: &[Vec<T>],
735    max_degree: usize,
736    alpha: f32,
737    dist: D,
738) -> Vec<u32>
739where
740    T: bytemuck::Pod + Copy + Send + Sync,
741    D: Distance<T> + Copy,
742{
743    if candidates.is_empty() {
744        return Vec::new();
745    }
746
747    let mut sorted = candidates.to_vec();
748    sorted.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
749
750    let mut pruned = Vec::<u32>::new();
751
752    for &(cand_id, cand_dist) in &sorted {
753        if cand_id as usize == node_id {
754            continue;
755        }
756        let mut ok = true;
757        for &sel in &pruned {
758            let d = dist.eval(&vectors[cand_id as usize], &vectors[sel as usize]);
759            if d < alpha * cand_dist {
760                ok = false;
761                break;
762            }
763        }
764        if ok {
765            pruned.push(cand_id);
766            if pruned.len() >= max_degree {
767                break;
768            }
769        }
770    }
771
772    // fill with closest if still not full
773    for &(cand_id, _) in &sorted {
774        if cand_id as usize == node_id {
775            continue;
776        }
777        if !pruned.contains(&cand_id) {
778            pruned.push(cand_id);
779            if pruned.len() >= max_degree {
780                break;
781            }
782        }
783    }
784
785    pruned
786}
787
/// Builds the reverse (incoming) edge lists of `new_graph` in CSR form.
///
/// `new_graph[pos]` holds the outgoing edges of node `order[pos]`. Returns
/// `(incoming_flat, offsets)` where the sources of all edges pointing at node `v` are
/// `incoming_flat[offsets[v]..offsets[v + 1]]`, listed in the order the edges are
/// encountered while walking `order`. `offsets` has `n + 1` entries.
fn build_incoming_csr(order: &[usize], new_graph: &[Vec<u32>], n: usize) -> (Vec<u32>, Vec<usize>) {
    let lists = &new_graph[..order.len()];

    // Count in-degrees one slot to the right, then turn the counts into running totals:
    // afterwards offsets[v] is the start of v's segment and offsets[n] the edge count.
    let mut offsets = vec![0usize; n + 1];
    for targets in lists {
        for &dst in targets {
            offsets[dst as usize + 1] += 1;
        }
    }
    for v in 0..n {
        offsets[v + 1] += offsets[v];
    }

    // Scatter each edge's source into the next free slot of its target's segment.
    let mut next_free = offsets[..n].to_vec();
    let mut incoming = vec![0u32; offsets[n]];
    for (&src, targets) in order.iter().zip(lists) {
        for &dst in targets {
            let slot = &mut next_free[dst as usize];
            incoming[*slot] = src as u32;
            *slot += 1;
        }
    }

    (incoming, offsets)
}
813
#[cfg(test)]
mod tests {
    use super::*;
    use anndists::dist::{DistCosine, DistL2};
    use rand::Rng;
    use std::fs;
    use std::path::PathBuf;

    /// Index file in the system temp directory that is deleted when the guard is dropped,
    /// so nothing is left behind even when a test fails on an assertion.
    ///
    /// Declare the guard before the index that uses it: locals drop in reverse order, so
    /// the memory map is released before the file is removed.
    struct TempIndex(PathBuf);

    impl TempIndex {
        fn new(name: &str) -> Self {
            // The process id keeps concurrent test runs from clobbering each other's files.
            let file_name = format!("rust_diskann_{}_{}.db", std::process::id(), name);
            let path = std::env::temp_dir().join(file_name);
            let _ = fs::remove_file(&path);
            Self(path)
        }

        fn path(&self) -> &str {
            self.0
                .to_str()
                .expect("temp dir path should be valid UTF-8")
        }
    }

    impl Drop for TempIndex {
        fn drop(&mut self) {
            let _ = fs::remove_file(&self.0);
        }
    }

    fn euclid(a: &[f32], b: &[f32]) -> f32 {
        a.iter()
            .zip(b)
            .map(|(x, y)| (x - y) * (x - y))
            .sum::<f32>()
            .sqrt()
    }

    #[test]
    fn test_small_index_l2() {
        let tmp = TempIndex::new("small_l2");

        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
            vec![0.5, 0.5],
        ];

        let index =
            DiskANN::<f32, DistL2>::build_index_default(&vectors, DistL2, tmp.path()).unwrap();

        let q = vec![0.1, 0.1];
        let nns = index.search(&q, 3, 8);
        assert_eq!(nns.len(), 3);

        // Verify the first neighbor is quite close in L2
        let v = index.get_vector(nns[0] as usize);
        assert!(euclid(&q, &v) < 1.0);
    }

    #[test]
    fn test_cosine() {
        let tmp = TempIndex::new("cosine");

        let vectors = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
            vec![1.0, 1.0, 0.0],
            vec![1.0, 0.0, 1.0],
        ];

        let index =
            DiskANN::<f32, DistCosine>::build_index_default(&vectors, DistCosine, tmp.path())
                .unwrap();

        let q = vec![2.0, 0.0, 0.0]; // parallel to [1,0,0]
        let nns = index.search(&q, 2, 8);
        assert_eq!(nns.len(), 2);

        // Top neighbor should have high cosine similarity (close direction)
        let v = index.get_vector(nns[0] as usize);
        let dot = v.iter().zip(&q).map(|(a, b)| a * b).sum::<f32>();
        let n1 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        let n2 = q.iter().map(|x| x * x).sum::<f32>().sqrt();
        let cos = dot / (n1 * n2);
        assert!(cos > 0.7);
    }

    #[test]
    fn test_persistence_and_open() {
        let tmp = TempIndex::new("persist");

        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
        ];

        {
            // Build, then drop the handle so only the file on disk remains
            let _idx =
                DiskANN::<f32, DistL2>::build_index_default(&vectors, DistL2, tmp.path()).unwrap();
        }

        // Use the default-metric opener (D: Default), keeping the same T
        let idx2 = DiskANN::<f32, DistL2>::open_index_default_metric(tmp.path()).unwrap();
        assert_eq!(idx2.num_vectors, 4);
        assert_eq!(idx2.dim, 2);

        let q = vec![0.9, 0.9];
        let res = idx2.search(&q, 2, 8);
        // [1,1] should be best (index 3)
        assert_eq!(res[0], 3);
    }

    #[test]
    fn test_grid_connectivity() {
        let tmp = TempIndex::new("grid");

        // 5x5 grid
        let mut vectors = Vec::new();
        for i in 0..5 {
            for j in 0..5 {
                vectors.push(vec![i as f32, j as f32]);
            }
        }

        let index = DiskANN::<f32, DistL2>::build_index_with_params(
            &vectors,
            DistL2,
            tmp.path(),
            DiskAnnParams {
                max_degree: 4,
                build_beam_width: 64,
                alpha: 1.5,
            },
        )
        .unwrap();

        for target in 0..vectors.len() {
            let q = &vectors[target];
            let nns = index.search(q, 10, 32);
            // If the exact point is missed, the best hit must at least be nearby
            if !nns.contains(&(target as u32)) {
                let v = index.get_vector(nns[0] as usize);
                assert!(euclid(q, &v) < 2.0);
            }
            for &nb in nns.iter().take(5) {
                let v = index.get_vector(nb as usize);
                assert!(euclid(q, &v) < 5.0);
            }
        }
    }

    #[test]
    fn test_medium_random() {
        let tmp = TempIndex::new("medium");

        let n = 200usize;
        let d = 32usize;
        let mut rng = rand::thread_rng();
        let vectors: Vec<Vec<f32>> = (0..n)
            .map(|_| (0..d).map(|_| rng.r#gen::<f32>()).collect())
            .collect();

        let index = DiskANN::<f32, DistL2>::build_index_with_params(
            &vectors,
            DistL2,
            tmp.path(),
            DiskAnnParams {
                max_degree: 32,
                build_beam_width: 128,
                alpha: 1.2,
            },
        )
        .unwrap();

        let q: Vec<f32> = (0..d).map(|_| rng.r#gen::<f32>()).collect();
        let res = index.search(&q, 10, 64);
        assert_eq!(res.len(), 10);

        // Ensure distances are nondecreasing
        let dists: Vec<f32> = res
            .iter()
            .map(|&id| {
                let v = index.get_vector(id as usize);
                euclid(&q, &v)
            })
            .collect();
        let mut sorted = dists.clone();
        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
        assert_eq!(dists, sorted);
    }
}