rust_diskann/
lib.rs

1//! # DiskAnn (generic over `anndists::Distance<f32>`)
2//!
3//! A minimal, on-disk, DiskANN-like library that:
4//! - Builds a Vamana-style graph (greedy + α-pruning) in memory
5//! - Writes vectors + fixed-degree adjacency to a single file
6//! - Memory-maps the file for low-overhead reads
7//! - Is **generic over any Distance<f32>** from `anndists` (L2, Cosine, Hamming, Dot, …)
8//!
9//! ## Example
10//! ```no_run
11//! use anndists::dist::{DistL2, DistCosine};
12//! use diskann_rs::{DiskANN, DiskAnnParams};
13//!
14//! // Build a new index from vectors, using L2 and default params
15//! let vectors = vec![vec![0.0; 128]; 1000];
16//! let index = DiskANN::<DistL2>::build_index_default(&vectors, DistL2{}, "index.db").unwrap();
17//!
18//! // Or with custom params
19//! let index2 = DiskANN::<DistCosine>::build_index_with_params(
20//!     &vectors,
21//!     DistCosine{},
22//!     "index_cos.db",
23//!     DiskAnnParams { max_degree: 48, ..Default::default() },
24//! ).unwrap();
25//!
26//! // Search the index
27//! let query = vec![0.0; 128];
28//! let neighbors = index.search(&query, 10, 64);
29//!
30//! // Open later (provide the same distance type)
31//! let _reopened = DiskANN::<DistL2>::open_index_default_metric("index.db").unwrap();
32//! ```
33//!
//! ## File Layout
//! ```text
//! [ metadata_len:u64 ][ metadata (bincode) ][ padding up to vectors_offset ]
//! [ vectors (num * dim * f32) ][ adjacency (num * max_degree * u32) ]
//! ```
//!
//! `vectors_offset` is currently fixed at 1 MiB, which leaves room for the
//! metadata header. Unused adjacency slots are padded with `u32::MAX`.
39
40use anndists::prelude::Distance;
41use bytemuck;
42use memmap2::Mmap;
43use rand::prelude::*;
44use rayon::prelude::*;
45use serde::{Deserialize, Serialize};
46use std::cmp::{Ordering, Reverse};
47use std::collections::{BinaryHeap, HashSet};
48use std::fs::OpenOptions;
49use std::io::{Read, Seek, SeekFrom, Write};
50use thiserror::Error;
51
/// Padding sentinel written into unused adjacency slots. `u32::MAX` is used
/// (rather than 0) so that padding can never be mistaken for node 0.
const PAD_U32: u32 = u32::MAX;

// Defaults for in-memory DiskANN builds (used by `build_index_default` and
// `DiskAnnParams::default`).

/// Default maximum number of out-edges stored per node.
pub const DISKANN_DEFAULT_MAX_DEGREE: usize = 64;
/// Default beam width used while building the graph.
pub const DISKANN_DEFAULT_BUILD_BEAM: usize = 128;
/// Default α used for neighbor pruning during the build.
pub const DISKANN_DEFAULT_ALPHA: f32 = 1.2;
59
/// Optional bag of build knobs, for callers that want to override just a few
/// and take the rest from `Default`.
#[derive(Clone, Copy, Debug)]
pub struct DiskAnnParams {
    /// Maximum number of edges stored per node.
    pub max_degree: usize,
    /// Beam width used by the greedy searches that run during graph construction.
    pub build_beam_width: usize,
    /// α parameter handed to neighbor pruning.
    pub alpha: f32,
}
impl Default for DiskAnnParams {
    /// Returns the crate-wide defaults (`DISKANN_DEFAULT_*`).
    fn default() -> Self {
        Self {
            max_degree: DISKANN_DEFAULT_MAX_DEGREE,
            build_beam_width: DISKANN_DEFAULT_BUILD_BEAM,
            alpha: DISKANN_DEFAULT_ALPHA,
        }
    }
}
76
/// Error type returned by the fallible DiskANN operations (building and opening an index).
#[derive(Debug, Error)]
pub enum DiskAnnError {
    /// An I/O error raised while creating, reading, writing or mapping the index file.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),

    /// The metadata header could not be serialized or deserialized with bincode.
    #[error("Serialization error: {0}")]
    Bincode(#[from] bincode::Error),

    /// An index-level problem (for example empty input or mismatched vector
    /// dimensions), described by the contained message.
    #[error("Index error: {0}")]
    IndexError(String),
}
92
/// Internal metadata header stored (bincode-encoded) at the start of the index file,
/// right after its own `u64` length prefix.
#[derive(Serialize, Deserialize, Debug)]
struct Metadata {
    /// Number of `f32` components per vector.
    dim: usize,
    /// Number of vectors stored in the file.
    num_vectors: usize,
    /// Number of `u32` adjacency slots stored per node.
    max_degree: usize,
    /// Node id used as the entry point for searches.
    medoid_id: u32,
    /// Byte offset of the vector region within the file.
    vectors_offset: u64,
    /// Byte offset of the adjacency region within the file.
    adjacency_offset: u64,
    /// `std::any::type_name` of the distance type the index was built with.
    distance_name: String,
}
104
/// Candidate for search/frontier queues: a node id paired with its distance to the query.
///
/// The ordering is a genuine total order: primarily by `dist` (IEEE-754
/// `total_cmp`, so NaN is ordered instead of being treated as "equal to
/// everything"), with ties broken by `id`. `PartialEq`, `PartialOrd` and `Ord`
/// are all derived from the same comparison, which is what `BinaryHeap`
/// requires to behave predictably.
#[derive(Clone, Copy)]
struct Candidate {
    /// Distance from the query to node `id`.
    dist: f32,
    /// Node id inside the index.
    id: u32,
}
impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so that `a == b` holds exactly when `cmp` is `Equal`.
        self.cmp(other) == Ordering::Equal
    }
}
impl Eq for Candidate {}
impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> Ordering {
        // Natural order by distance: smaller is "less"; the id makes ties deterministic.
        self.dist
            .total_cmp(&other.dist)
            .then_with(|| self.id.cmp(&other.id))
    }
}
128
/// Main struct representing a DiskANN index (generic over the distance `D`).
///
/// Vectors and adjacency lists are read straight out of the memory-mapped
/// index file; only the small header fields are kept in the struct itself.
pub struct DiskANN<D>
where
    D: Distance<f32> + Send + Sync + Copy + Clone + 'static,
{
    /// Dimensionality of vectors in the index
    pub dim: usize,
    /// Number of vectors in the index
    pub num_vectors: usize,
    /// Maximum number of edges per node (also the number of adjacency slots per node on disk)
    pub max_degree: usize,
    /// Informational: type name of the distance (from metadata)
    pub distance_name: String,

    /// ID of the medoid (used as entry point)
    medoid_id: u32,
    /// Byte offset of the vector region inside the mapped file
    vectors_offset: u64,
    /// Byte offset of the adjacency region inside the mapped file
    adjacency_offset: u64,

    /// Memory-mapped file
    mmap: Mmap,

    /// The distance strategy
    dist: D,
}
155
156// constructors
157
158impl<D> DiskANN<D>
159where
160    D: Distance<f32> + Send + Sync + Copy + Clone + 'static,
161{
162    /// Build with default parameters: (M=32, L=256, alpha=1.2).
163    pub fn build_index_default(
164        vectors: &[Vec<f32>],
165        dist: D,
166        file_path: &str,
167    ) -> Result<Self, DiskAnnError> {
168        Self::build_index(
169            vectors,
170            DISKANN_DEFAULT_MAX_DEGREE,
171            DISKANN_DEFAULT_BUILD_BEAM,
172            DISKANN_DEFAULT_ALPHA,
173            dist,
174            file_path,
175        )
176    }
177
178    /// Build with a `DiskAnnParams` bundle.
179    pub fn build_index_with_params(
180        vectors: &[Vec<f32>],
181        dist: D,
182        file_path: &str,
183        p: DiskAnnParams,
184    ) -> Result<Self, DiskAnnError> {
185        Self::build_index(
186            vectors,
187            p.max_degree,
188            p.build_beam_width,
189            p.alpha,
190            dist,
191            file_path,
192        )
193    }
194}
195
196/// Extra sugar when your distance type implements `Default` (most unit-struct metrics do).
197impl<D> DiskANN<D>
198where
199    D: Distance<f32> + Default + Send + Sync + Copy + Clone + 'static,
200{
201    /// Build with default params **and** `D::default()` metric.
202    pub fn build_index_default_metric(
203        vectors: &[Vec<f32>],
204        file_path: &str,
205    ) -> Result<Self, DiskAnnError> {
206        Self::build_index_default(vectors, D::default(), file_path)
207    }
208
209    /// Open an index using `D::default()` as the distance (matches what you built with).
210    pub fn open_index_default_metric(path: &str) -> Result<Self, DiskAnnError> {
211        Self::open_index_with(path, D::default())
212    }
213}
214
215impl<D> DiskANN<D>
216where
217    D: Distance<f32> + Send + Sync + Copy + Clone + 'static,
218{
219    /// Builds a new index from provided vectors
220    ///
221    /// # Arguments
222    /// * `vectors` - The vectors to index (slice of Vec<f32>)
223    /// * `max_degree` - Maximum edges per node (M ~ 24-64)
224    /// * `build_beam_width` - Construction L (e.g., 128-400)
225    /// * `alpha` - Pruning parameter (1.2–2.0)
226    /// * `dist` - Any `anndists::Distance<f32>` (e.g., `DistL2`)
227    /// * `file_path` - Path of index file
228    pub fn build_index(
229        vectors: &[Vec<f32>],
230        max_degree: usize,
231        build_beam_width: usize,
232        alpha: f32,
233        dist: D,
234        file_path: &str,
235    ) -> Result<Self, DiskAnnError> {
236        if vectors.is_empty() {
237            return Err(DiskAnnError::IndexError("No vectors provided".to_string()));
238        }
239
240        let num_vectors = vectors.len();
241        let dim = vectors[0].len();
242        for (i, v) in vectors.iter().enumerate() {
243            if v.len() != dim {
244                return Err(DiskAnnError::IndexError(format!(
245                    "Vector {} has dimension {} but expected {}",
246                    i,
247                    v.len(),
248                    dim
249                )));
250            }
251        }
252
253        let mut file = OpenOptions::new()
254            .create(true)
255            .write(true)
256            .read(true)
257            .truncate(true)
258            .open(file_path)?;
259
260        // Reserve space for metadata (we'll write it after data)
261        let vectors_offset = 1024 * 1024;
262        let total_vector_bytes = (num_vectors as u64) * (dim as u64) * 4;
263
264        // Write vectors contiguous (sequential I/O is fastest)
265        file.seek(SeekFrom::Start(vectors_offset))?;
266        for vector in vectors {
267            let bytes = bytemuck::cast_slice(vector);
268            file.write_all(bytes)?;
269        }
270
271        // Compute medoid using provided distance (parallelized distance eval)
272        let medoid_id = calculate_medoid(vectors, dist);
273
274        // Build Vamana-like graph (stronger refinement, parallel inner loops)
275        let adjacency_offset = vectors_offset as u64 + total_vector_bytes;
276        let graph = build_vamana_graph(
277            vectors,
278            max_degree,
279            build_beam_width,
280            alpha,
281            dist,
282            medoid_id as u32,
283        );
284
285        // Write adjacency lists (fixed max_degree, pad with PAD_U32)
286        file.seek(SeekFrom::Start(adjacency_offset))?;
287        for neighbors in &graph {
288            let mut padded = neighbors.clone();
289            padded.resize(max_degree, PAD_U32);
290            let bytes = bytemuck::cast_slice(&padded);
291            file.write_all(bytes)?;
292        }
293
294        // Write metadata
295        let metadata = Metadata {
296            dim,
297            num_vectors,
298            max_degree,
299            medoid_id: medoid_id as u32,
300            vectors_offset: vectors_offset as u64,
301            adjacency_offset,
302            distance_name: std::any::type_name::<D>().to_string(),
303        };
304
305        let md_bytes = bincode::serialize(&metadata)?;
306        file.seek(SeekFrom::Start(0))?;
307        let md_len = md_bytes.len() as u64;
308        file.write_all(&md_len.to_le_bytes())?;
309        file.write_all(&md_bytes)?;
310        file.sync_all()?;
311
312        // Memory map the file
313        let mmap = unsafe { memmap2::Mmap::map(&file)? };
314
315        Ok(Self {
316            dim,
317            num_vectors,
318            max_degree,
319            distance_name: metadata.distance_name,
320            medoid_id: metadata.medoid_id,
321            vectors_offset: metadata.vectors_offset,
322            adjacency_offset: metadata.adjacency_offset,
323            mmap,
324            dist,
325        })
326    }
327
328    /// Opens an existing index file, supplying the distance strategy explicitly.
329    pub fn open_index_with(path: &str, dist: D) -> Result<Self, DiskAnnError> {
330        let mut file = OpenOptions::new().read(true).write(false).open(path)?;
331
332        // Read metadata length
333        let mut buf8 = [0u8; 8];
334        file.seek(SeekFrom::Start(0))?;
335        file.read_exact(&mut buf8)?;
336        let md_len = u64::from_le_bytes(buf8);
337
338        // Read metadata
339        let mut md_bytes = vec![0u8; md_len as usize];
340        file.read_exact(&mut md_bytes)?;
341        let metadata: Metadata = bincode::deserialize(&md_bytes)?;
342
343        let mmap = unsafe { memmap2::Mmap::map(&file)? };
344
345        // Optional sanity/logging: warn if type differs from recorded name
346        let expected = std::any::type_name::<D>();
347        if metadata.distance_name != expected {
348            eprintln!(
349                "Warning: index recorded distance `{}` but you opened with `{}`",
350                metadata.distance_name, expected
351            );
352        }
353
354        Ok(Self {
355            dim: metadata.dim,
356            num_vectors: metadata.num_vectors,
357            max_degree: metadata.max_degree,
358            distance_name: metadata.distance_name,
359            medoid_id: metadata.medoid_id,
360            vectors_offset: metadata.vectors_offset,
361            adjacency_offset: metadata.adjacency_offset,
362            mmap,
363            dist,
364        })
365    }
366
367    /// Searches the index for nearest neighbors using a best-first beam search.
368    /// Termination rule: continue while the best frontier can still improve the worst in working set.
369    /// Like `search` but also returns the distance for each neighbor.
370    /// Distances are exactly the ones computed during the beam search.
371    pub fn search_with_dists(&self, query: &[f32], k: usize, beam_width: usize) -> Vec<(u32, f32)> {
372        assert_eq!(
373            query.len(),
374            self.dim,
375            "Query dim {} != index dim {}",
376            query.len(),
377            self.dim
378        );
379
380        #[derive(Clone, Copy)]
381        struct Candidate {
382            dist: f32,
383            id: u32,
384        }
385        impl PartialEq for Candidate {
386            fn eq(&self, o: &Self) -> bool {
387                self.dist == o.dist && self.id == o.id
388            }
389        }
390        impl Eq for Candidate {}
391        impl PartialOrd for Candidate {
392            fn partial_cmp(&self, o: &Self) -> Option<Ordering> {
393                self.dist.partial_cmp(&o.dist)
394            }
395        }
396        impl Ord for Candidate {
397            fn cmp(&self, o: &Self) -> Ordering {
398                self.partial_cmp(o).unwrap_or(Ordering::Equal)
399            }
400        }
401
402        let mut visited = HashSet::new();
403        let mut frontier: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new(); // best-first by dist
404        let mut w: BinaryHeap<Candidate> = BinaryHeap::new(); // working set, max-heap by dist
405
406        // seed from medoid
407        let start_dist = self.distance_to(query, self.medoid_id as usize);
408        let start = Candidate {
409            dist: start_dist,
410            id: self.medoid_id,
411        };
412        frontier.push(Reverse(start));
413        w.push(start);
414        visited.insert(self.medoid_id);
415
416        // expand while best frontier can still improve worst in working set
417        while let Some(Reverse(best)) = frontier.peek().copied() {
418            if w.len() >= beam_width {
419                if let Some(worst) = w.peek() {
420                    if best.dist >= worst.dist {
421                        break;
422                    }
423                }
424            }
425            let Reverse(current) = frontier.pop().unwrap();
426
427            for &nb in self.get_neighbors(current.id) {
428                if nb == PAD_U32 {
429                    continue;
430                }
431                if !visited.insert(nb) {
432                    continue;
433                }
434
435                let d = self.distance_to(query, nb as usize);
436                let cand = Candidate { dist: d, id: nb };
437
438                if w.len() < beam_width {
439                    w.push(cand);
440                    frontier.push(Reverse(cand));
441                } else if d < w.peek().unwrap().dist {
442                    w.pop();
443                    w.push(cand);
444                    frontier.push(Reverse(cand));
445                }
446            }
447        }
448
449        // top-k by distance, keep distances
450        let mut results: Vec<_> = w.into_vec();
451        results.sort_by(|a, b| a.dist.partial_cmp(&b.dist).unwrap());
452        results.truncate(k);
453        results.into_iter().map(|c| (c.id, c.dist)).collect()
454    }
455    /// search but only return neighbor ids
456    pub fn search(&self, query: &[f32], k: usize, beam_width: usize) -> Vec<u32> {
457        self.search_with_dists(query, k, beam_width)
458            .into_iter()
459            .map(|(id, _dist)| id)
460            .collect()
461    }
462
463    /// Gets the neighbors of a node from the (fixed-degree) adjacency region
464    fn get_neighbors(&self, node_id: u32) -> &[u32] {
465        let offset = self.adjacency_offset + (node_id as u64 * self.max_degree as u64 * 4);
466        let start = offset as usize;
467        let end = start + (self.max_degree * 4);
468        let bytes = &self.mmap[start..end];
469        bytemuck::cast_slice(bytes)
470    }
471
472    /// Computes distance between `query` and vector `idx`
473    fn distance_to(&self, query: &[f32], idx: usize) -> f32 {
474        let offset = self.vectors_offset + (idx as u64 * self.dim as u64 * 4);
475        let start = offset as usize;
476        let end = start + (self.dim * 4);
477        let bytes = &self.mmap[start..end];
478        let vector: &[f32] = bytemuck::cast_slice(bytes);
479        self.dist.eval(query, vector)
480    }
481
482    /// Gets a vector from the index (useful for tests)
483    pub fn get_vector(&self, idx: usize) -> Vec<f32> {
484        let offset = self.vectors_offset + (idx as u64 * self.dim as u64 * 4);
485        let start = offset as usize;
486        let end = start + (self.dim * 4);
487        let bytes = &self.mmap[start..end];
488        let vector: &[f32] = bytemuck::cast_slice(bytes);
489        vector.to_vec()
490    }
491}
492
493/// Calculates the medoid (vector closest to the centroid) using distance `D`
494/// Parallelizes the per-vector distance evaluations.
495fn calculate_medoid<D: Distance<f32> + Copy + Sync>(vectors: &[Vec<f32>], dist: D) -> usize {
496    let dim = vectors[0].len();
497    let mut centroid = vec![0.0f32; dim];
498
499    for v in vectors {
500        for (i, &val) in v.iter().enumerate() {
501            centroid[i] += val;
502        }
503    }
504    for val in &mut centroid {
505        *val /= vectors.len() as f32;
506    }
507
508    let (best_idx, _best_dist) = vectors
509        .par_iter()
510        .enumerate()
511        .map(|(idx, v)| (idx, dist.eval(&centroid, v)))
512        .reduce(|| (0usize, f32::MAX), |a, b| if a.1 <= b.1 { a } else { b });
513
514    best_idx
515}
516
/// Builds a strengthened Vamana-like graph using multi-pass refinement.
/// - Multi-seed candidate gathering (medoid + random seeds)
/// - Union with current adjacency before α-prune
/// - 2 refinement passes with symmetrization after each pass
///
/// Returns one adjacency list per input vector, indexed by node id.
/// The bootstrap, the visit order and the extra seeds all come from
/// `thread_rng`, so the resulting graph differs from run to run.
fn build_vamana_graph<D: Distance<f32> + Copy + Sync>(
    vectors: &[Vec<f32>],
    max_degree: usize,
    build_beam_width: usize,
    alpha: f32,
    dist: D,
    medoid_id: u32,
) -> Vec<Vec<u32>> {
    let n = vectors.len();
    let mut graph = vec![Vec::<u32>::new(); n];

    // Light random bootstrap to avoid disconnected starts: every node gets
    // `target` distinct random neighbors other than itself.
    {
        let mut rng = thread_rng();
        for i in 0..n {
            let mut s = HashSet::new();
            // At least 2 (or max_degree / 2), but never more than the n - 1
            // other nodes that exist, so the rejection loop below terminates.
            let target = (max_degree / 2).max(2).min(n.saturating_sub(1));
            while s.len() < target {
                let nb = rng.gen_range(0..n);
                if nb != i {
                    s.insert(nb as u32);
                }
            }
            graph[i] = s.into_iter().collect();
        }
    }

    // Refinement passes
    const PASSES: usize = 2;
    const EXTRA_SEEDS: usize = 2;

    let mut rng = thread_rng();
    // The number of passes is key to the quality of the final graph, more passes can lead to better refinement but also increase computation time.
    for _pass in 0..PASSES {
        // Shuffle visit order each pass
        let mut order: Vec<usize> = (0..n).collect();
        order.shuffle(&mut rng);

        // Snapshot read of graph for parallel candidate building
        // (`graph` is only replaced after both parallel stages are done).
        let snapshot = &graph;

        // Build new neighbor proposals in parallel.
        // NOTE: `new_graph[pos]` holds the proposal for node `order[pos]`,
        // i.e. it is indexed by position in `order`, not by node id.
        let new_graph: Vec<Vec<u32>> = order
            .par_iter()
            .map(|&u| {
                // the EXTRA_SEEDS random seeds help explore the graph better during construction, but at a cost of more computation.
                let mut candidates: Vec<(u32, f32)> =
                    Vec::with_capacity(build_beam_width * (2 + EXTRA_SEEDS));

                // Include current adjacency with distances
                for &nb in &snapshot[u] {
                    let d = dist.eval(&vectors[u], &vectors[nb as usize]);
                    candidates.push((nb, d));
                }

                // Seeds: always medoid + some random starts
                let mut seeds = Vec::with_capacity(1 + EXTRA_SEEDS);
                seeds.push(medoid_id as usize);
                let mut trng = thread_rng();
                for _ in 0..EXTRA_SEEDS {
                    seeds.push(trng.gen_range(0..n));
                }

                // Gather candidates from greedy searches
                for start in seeds {
                    let mut part = greedy_search(
                        &vectors[u],
                        vectors,
                        snapshot,
                        start,
                        build_beam_width,
                        dist,
                    );
                    candidates.append(&mut part);
                }

                // Deduplicate by id keeping best distance: sort by id so
                // duplicates are adjacent, then fold each run into the entry
                // that is kept (`b`), copying over a smaller distance from the
                // entry being removed (`a`).
                candidates.sort_by(|a, b| a.0.cmp(&b.0));
                candidates.dedup_by(|a, b| {
                    if a.0 == b.0 {
                        if a.1 < b.1 {
                            *b = *a;
                        }
                        true
                    } else {
                        false
                    }
                });

                // α-prune around u
                prune_neighbors(u, &candidates, vectors, max_degree, alpha, dist)
            })
            .collect();

        // Symmetrize: union incoming + outgoing, then α-prune again (parallel)
        // Build inverse map: node-id -> position in `order`
        let mut pos_of = vec![0usize; n];
        for (pos, &u) in order.iter().enumerate() {
            pos_of[u] = pos;
        }

        // Build incoming as CSR: the sources of edges into node `u` are
        // `incoming_flat[incoming_off[u]..incoming_off[u + 1]]`.
        let (incoming_flat, incoming_off) = build_incoming_csr(&order, &new_graph, n);

        // Union + prune in parallel; the result is indexed by node id again.
        graph = (0..n)
            .into_par_iter()
            .map(|u| {
                let ng = &new_graph[pos_of[u]]; // outgoing from this pass
                let inc = &incoming_flat[incoming_off[u]..incoming_off[u + 1]]; // incoming to u

                // pool = union(outgoing ∪ incoming) with tiny, cache-friendly ops
                let mut pool_ids: Vec<u32> = Vec::with_capacity(ng.len() + inc.len());
                pool_ids.extend_from_slice(ng);
                pool_ids.extend_from_slice(inc);
                pool_ids.sort_unstable();
                pool_ids.dedup();

                // compute distances once, then α-prune (self-loops are dropped here)
                let pool: Vec<(u32, f32)> = pool_ids
                    .into_iter()
                    .filter(|&id| id as usize != u)
                    .map(|id| (id, dist.eval(&vectors[u], &vectors[id as usize])))
                    .collect();

                prune_neighbors(u, &pool, vectors, max_degree, alpha, dist)
            })
            .collect();
    }

    // Final cleanup: lists within `max_degree` pass through unchanged, longer
    // ones are sent through `prune_neighbors` once more.
    graph
        .into_par_iter()
        .enumerate()
        .map(|(u, neigh)| {
            if neigh.len() <= max_degree {
                return neigh;
            }
            let pool: Vec<(u32, f32)> = neigh
                .iter()
                .map(|&id| (id, dist.eval(&vectors[u], &vectors[id as usize])))
                .collect();
            prune_neighbors(u, &pool, vectors, max_degree, alpha, dist)
        })
        .collect()
}
667
668/// Greedy search used during construction (read-only on `graph`)
669/// Same termination rule as query-time search.
670fn greedy_search<D: Distance<f32> + Copy>(
671    query: &[f32],
672    vectors: &[Vec<f32>],
673    graph: &[Vec<u32>],
674    start_id: usize,
675    beam_width: usize,
676    dist: D,
677) -> Vec<(u32, f32)> {
678    let mut visited = HashSet::new();
679    let mut frontier: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new(); // min-heap by dist
680    let mut w: BinaryHeap<Candidate> = BinaryHeap::new(); // max-heap by dist
681
682    let start_dist = dist.eval(query, &vectors[start_id]);
683    let start = Candidate {
684        dist: start_dist,
685        id: start_id as u32,
686    };
687    frontier.push(Reverse(start));
688    w.push(start);
689    visited.insert(start_id as u32);
690
691    while let Some(Reverse(best)) = frontier.peek().copied() {
692        if w.len() >= beam_width {
693            if let Some(worst) = w.peek() {
694                if best.dist >= worst.dist {
695                    break;
696                }
697            }
698        }
699        let Reverse(cur) = frontier.pop().unwrap();
700
701        for &nb in &graph[cur.id as usize] {
702            if !visited.insert(nb) {
703                continue;
704            }
705            let d = dist.eval(query, &vectors[nb as usize]);
706            let cand = Candidate { dist: d, id: nb };
707
708            if w.len() < beam_width {
709                w.push(cand);
710                frontier.push(Reverse(cand));
711            } else if d < w.peek().unwrap().dist {
712                w.pop();
713                w.push(cand);
714                frontier.push(Reverse(cand));
715            }
716        }
717    }
718
719    let mut v = w.into_vec();
720    v.sort_by(|a, b| a.dist.partial_cmp(&b.dist).unwrap());
721    v.into_iter().map(|c| (c.id, c.dist)).collect()
722}
723
724/// α-pruning from DiskANN/Vamana
725fn prune_neighbors<D: Distance<f32> + Copy>(
726    node_id: usize,
727    candidates: &[(u32, f32)],
728    vectors: &[Vec<f32>],
729    max_degree: usize,
730    alpha: f32,
731    dist: D,
732) -> Vec<u32> {
733    if candidates.is_empty() {
734        return Vec::new();
735    }
736
737    let mut sorted = candidates.to_vec();
738    sorted.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
739
740    let mut pruned = Vec::<u32>::new();
741
742    for &(cand_id, cand_dist) in &sorted {
743        if cand_id as usize == node_id {
744            continue;
745        }
746        let mut ok = true;
747        for &sel in &pruned {
748            let d = dist.eval(&vectors[cand_id as usize], &vectors[sel as usize]);
749            if d < alpha * cand_dist {
750                ok = false;
751                break;
752            }
753        }
754        if ok {
755            pruned.push(cand_id);
756            if pruned.len() >= max_degree {
757                break;
758            }
759        }
760    }
761
762    // fill with closest if still not full
763    for &(cand_id, _) in &sorted {
764        if cand_id as usize == node_id {
765            continue;
766        }
767        if !pruned.contains(&cand_id) {
768            pruned.push(cand_id);
769            if pruned.len() >= max_degree {
770                break;
771            }
772        }
773    }
774
775    pruned
776}
777
/// Builds the reverse (incoming-edge) view of `new_graph` in CSR form.
///
/// `new_graph[pos]` is the outgoing list of node `order[pos]`. The result is
/// `(flat, offsets)` with `offsets.len() == n + 1`: the sources of all edges
/// pointing at node `v` are `flat[offsets[v]..offsets[v + 1]]`, in the order
/// the edges are met while walking `order`.
fn build_incoming_csr(order: &[usize], new_graph: &[Vec<u32>], n: usize) -> (Vec<u32>, Vec<usize>) {
    // 1) count in-degrees, stored one slot to the right so that the prefix
    //    sum below turns them directly into start offsets
    let mut offsets = vec![0usize; n + 1];
    for (pos, _src) in order.iter().enumerate() {
        for &dst in &new_graph[pos] {
            offsets[dst as usize + 1] += 1;
        }
    }

    // 2) in-place prefix sum → offsets[v] is the start of v's incoming run
    for v in 0..n {
        offsets[v + 1] += offsets[v];
    }

    // 3) scatter every edge source into the next free slot of its target's run
    let mut next_slot = offsets.clone();
    let mut flat = vec![0u32; offsets[n]];
    for (pos, &src) in order.iter().enumerate() {
        for &dst in &new_graph[pos] {
            let slot = &mut next_slot[dst as usize];
            flat[*slot] = src as u32;
            *slot += 1;
        }
    }

    (flat, offsets)
}
803
804#[cfg(test)]
805mod tests {
806    use super::*;
807    use anndists::dist::{DistCosine, DistL2};
808    use rand::Rng;
809    use std::fs;
810
/// Reference Euclidean (L2) distance used by the tests to check results
/// independently of the index's own distance implementation. If the slices
/// differ in length, the extra components of the longer one are ignored.
fn euclid(a: &[f32], b: &[f32]) -> f32 {
    let squared: f32 = a
        .iter()
        .zip(b.iter())
        .map(|(&x, &y)| {
            let diff = x - y;
            diff * diff
        })
        .sum();
    squared.sqrt()
}
818
/// Builds a 5-point 2-D index with L2 and the default parameters, then checks
/// that a query near the origin returns 3 results whose top hit is close.
#[test]
fn test_small_index_l2() {
    let path = "test_small_l2.db";
    // Remove leftovers from an earlier (possibly failed) run.
    let _ = fs::remove_file(path);

    let vectors = vec![
        vec![0.0, 0.0],
        vec![1.0, 0.0],
        vec![0.0, 1.0],
        vec![1.0, 1.0],
        vec![0.5, 0.5],
    ];

    let index = DiskANN::<DistL2>::build_index_default(&vectors, DistL2 {}, path).unwrap();

    let q = vec![0.1, 0.1];
    let nns = index.search(&q, 3, 8);
    assert_eq!(nns.len(), 3);

    // Verify the first neighbor is quite close in L2
    let v = index.get_vector(nns[0] as usize);
    assert!(euclid(&q, &v) < 1.0);

    // Best-effort cleanup; not reached if an assertion above fails.
    let _ = fs::remove_file(path);
}
844
/// Builds a small 3-D index with the cosine distance and checks that the top
/// hit for a query along the x axis points in a similar direction.
#[test]
fn test_cosine() {
    let path = "test_cosine.db";
    // Remove leftovers from an earlier (possibly failed) run.
    let _ = fs::remove_file(path);

    let vectors = vec![
        vec![1.0, 0.0, 0.0],
        vec![0.0, 1.0, 0.0],
        vec![0.0, 0.0, 1.0],
        vec![1.0, 1.0, 0.0],
        vec![1.0, 0.0, 1.0],
    ];

    let index =
        DiskANN::<DistCosine>::build_index_default(&vectors, DistCosine {}, path).unwrap();

    let q = vec![2.0, 0.0, 0.0]; // parallel to [1,0,0]
    let nns = index.search(&q, 2, 8);
    assert_eq!(nns.len(), 2);

    // Top neighbor should have high cosine similarity (close direction);
    // the similarity is recomputed by hand from the stored vector.
    let v = index.get_vector(nns[0] as usize);
    let dot = v.iter().zip(&q).map(|(a, b)| a * b).sum::<f32>();
    let n1 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let n2 = q.iter().map(|x| x * x).sum::<f32>().sqrt();
    let cos = dot / (n1 * n2);
    assert!(cos > 0.7);

    // Best-effort cleanup; not reached if an assertion above fails.
    let _ = fs::remove_file(path);
}
875
/// Builds an index, drops it, reopens the file with
/// `open_index_default_metric` and checks the header fields and a search.
#[test]
fn test_persistence_and_open() {
    let path = "test_persist.db";
    // Remove leftovers from an earlier (possibly failed) run.
    let _ = fs::remove_file(path);

    let vectors = vec![
        vec![0.0, 0.0],
        vec![1.0, 0.0],
        vec![0.0, 1.0],
        vec![1.0, 1.0],
    ];

    // Inner scope: the freshly built index (and its mapping) is dropped
    // before the file is reopened.
    {
        let _idx = DiskANN::<DistL2>::build_index_default(&vectors, DistL2 {}, path).unwrap();
    }

    let idx2 = DiskANN::<DistL2>::open_index_default_metric(path).unwrap();
    assert_eq!(idx2.num_vectors, 4);
    assert_eq!(idx2.dim, 2);

    let q = vec![0.9, 0.9];
    let res = idx2.search(&q, 2, 8);
    // [1,1] should be best
    assert_eq!(res[0], 3);

    // Best-effort cleanup; not reached if an assertion above fails.
    let _ = fs::remove_file(path);
}
903
/// Builds a low-degree index over a 5x5 integer grid and, for every grid
/// point used as a query, checks that the returned neighbors are nearby.
#[test]
fn test_grid_connectivity() {
    let path = "test_grid.db";
    // Remove leftovers from an earlier (possibly failed) run.
    let _ = fs::remove_file(path);

    // 5x5 grid
    let mut vectors = Vec::new();
    for i in 0..5 {
        for j in 0..5 {
            vectors.push(vec![i as f32, j as f32]);
        }
    }

    let index = DiskANN::<DistL2>::build_index_with_params(
        &vectors,
        DistL2 {},
        path,
        DiskAnnParams {
            max_degree: 4,
            build_beam_width: 64,
            alpha: 1.5,
        },
    )
    .unwrap();

    for target in 0..vectors.len() {
        let q = &vectors[target];
        let nns = index.search(q, 10, 32);
        // The search is approximate: if the exact point is missed, the best
        // hit must still be within distance 2.
        if !nns.contains(&(target as u32)) {
            let v = index.get_vector(nns[0] as usize);
            assert!(euclid(q, &v) < 2.0);
        }
        // Loose bound on the first five hits.
        for &nb in nns.iter().take(5) {
            let v = index.get_vector(nb as usize);
            assert!(euclid(q, &v) < 5.0);
        }
    }

    // Best-effort cleanup; not reached if an assertion above fails.
    let _ = fs::remove_file(path);
}
944
/// Builds an index over 200 random 32-D vectors and checks that a random
/// query returns 10 results ordered by nondecreasing L2 distance.
#[test]
fn test_medium_random() {
    let path = "test_medium.db";
    // Remove leftovers from an earlier (possibly failed) run.
    let _ = fs::remove_file(path);

    let n = 200usize;
    let d = 32usize;
    let mut rng = rand::thread_rng();
    // `r#gen` is the raw-identifier spelling of the `gen` method.
    let vectors: Vec<Vec<f32>> = (0..n)
        .map(|_| (0..d).map(|_| rng.r#gen::<f32>()).collect())
        .collect();

    let index = DiskANN::<DistL2>::build_index_with_params(
        &vectors,
        DistL2 {},
        path,
        DiskAnnParams {
            max_degree: 32,
            build_beam_width: 128,
            alpha: 1.2,
        },
    )
    .unwrap();

    let q: Vec<f32> = (0..d).map(|_| rng.r#gen::<f32>()).collect();
    let res = index.search(&q, 10, 64);
    assert_eq!(res.len(), 10);

    // Ensure distances are nondecreasing (recomputed with the test's own
    // `euclid`, then compared against a sorted copy)
    let dists: Vec<f32> = res
        .iter()
        .map(|&id| {
            let v = index.get_vector(id as usize);
            euclid(&q, &v)
        })
        .collect();
    let mut sorted = dists.clone();
    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
    assert_eq!(dists, sorted);

    // Best-effort cleanup; not reached if an assertion above fails.
    let _ = fs::remove_file(path);
}
987}