Skip to main content

ruvector_core/advanced_features/
diskann.rs

1//! DiskANN / Vamana SSD-Backed Approximate Nearest Neighbor Index
2//!
3//! Implements the Vamana graph index from the DiskANN paper (Subramanya et al., 2019).
4//! Each node connects to R neighbors chosen via **alpha-RNG pruning** -- a relaxed
5//! Relative Neighborhood Graph balancing proximity and angular diversity.
6//!
7//! # Why DiskANN achieves 95%+ recall at sub-10ms
8//!
9//! - **Vamana graph**: alpha > 1.0 retains long-range shortcuts for O(log n) hops.
10//! - **SSD layout**: node vector + neighbors packed in aligned pages; one read per hop.
11//! - **Page cache**: LRU cache keeps hot pages in memory (80-95% hit rates typical).
12//! - **Filtered traversal**: predicates evaluated during search, not post-filter.
13//!
14//! # Alpha-RNG Pruning
15//!
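//! A candidate `c` for node `p` is kept only if, for every already-selected neighbor `n`,
//! `dist(p, c) <= alpha * dist(n, c)`, ensuring angular diversity. In this implementation
//! `dist` is the squared L2 distance, so `alpha` scales squared distances.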
18
19use crate::error::{Result, RuvectorError};
20use serde::{Deserialize, Serialize};
21use std::cmp::Reverse;
22use std::collections::{BinaryHeap, HashMap, HashSet};
23
24/// Configuration for the Vamana graph index.
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct VamanaConfig {
27    /// Maximum out-degree per node (R). Typical: 32-64.
28    pub max_degree: usize,
29    /// Search list size (L). Larger = better recall, slower search.
30    pub search_list_size: usize,
31    /// Pruning parameter (>= 1.0). Typical: 1.2.
32    pub alpha: f32,
33    /// Thread count for build (reserved for future parallel builds).
34    pub num_build_threads: usize,
35    /// Page size for SSD-aligned layout in bytes.
36    pub ssd_page_size: usize,
37}
38
39impl Default for VamanaConfig {
40    fn default() -> Self {
41        Self {
42            max_degree: 32,
43            search_list_size: 64,
44            alpha: 1.2,
45            num_build_threads: 1,
46            ssd_page_size: 4096,
47        }
48    }
49}
50
51impl VamanaConfig {
52    /// Validate configuration parameters.
53    pub fn validate(&self) -> Result<()> {
54        if self.max_degree == 0 {
55            return Err(RuvectorError::InvalidParameter(
56                "max_degree must be > 0".into(),
57            ));
58        }
59        if self.search_list_size < 1 {
60            return Err(RuvectorError::InvalidParameter(
61                "search_list_size must be >= 1".into(),
62            ));
63        }
64        if self.alpha < 1.0 {
65            return Err(RuvectorError::InvalidParameter(
66                "alpha must be >= 1.0".into(),
67            ));
68        }
69        Ok(())
70    }
71}
72
73/// In-memory Vamana graph for building and searching.
74#[derive(Debug, Clone)]
75pub struct VamanaGraph {
76    /// Adjacency lists per node.
77    pub neighbors: Vec<Vec<u32>>,
78    /// Vectors, row-major.
79    pub vectors: Vec<Vec<f32>>,
80    /// Medoid (entry point) index.
81    pub medoid: u32,
82    /// Build config.
83    pub config: VamanaConfig,
84}
85
86impl VamanaGraph {
87    /// Build a Vamana graph: find medoid, init neighbors, then refine via greedy search + robust prune.
88    pub fn build(vectors: Vec<Vec<f32>>, config: VamanaConfig) -> Result<Self> {
89        config.validate()?;
90        let n = vectors.len();
91        if n == 0 {
92            return Ok(Self {
93                neighbors: vec![],
94                vectors: vec![],
95                medoid: 0,
96                config,
97            });
98        }
99        let dim = vectors[0].len();
100        for v in &vectors {
101            if v.len() != dim {
102                return Err(RuvectorError::DimensionMismatch {
103                    expected: dim,
104                    actual: v.len(),
105                });
106            }
107        }
108        let medoid = MedoidFinder::find_medoid(&vectors);
109        let mut graph = Self {
110            neighbors: vec![vec![]; n],
111            vectors,
112            medoid,
113            config,
114        };
115        // Initialize with sequential neighbors.
116        for i in 0..n {
117            let mut nb = Vec::new();
118            for j in 0..n.min(graph.config.max_degree + 1) {
119                if j != i {
120                    nb.push(j as u32);
121                }
122                if nb.len() >= graph.config.max_degree {
123                    break;
124                }
125            }
126            graph.neighbors[i] = nb;
127        }
128        // Refine: search, prune, add reverse edges.
129        for i in 0..n {
130            let query = graph.vectors[i].clone();
131            let (cands, _) = graph.greedy_search_internal(&query, graph.config.search_list_size);
132            let mut cset: Vec<u32> = cands.into_iter().filter(|&c| c != i as u32).collect();
133            for &nb in &graph.neighbors[i] {
134                if !cset.contains(&nb) {
135                    cset.push(nb);
136                }
137            }
138            let pruned = graph.robust_prune(i as u32, &cset);
139            graph.neighbors[i] = pruned.clone();
140            for &nb in &pruned {
141                let ni = nb as usize;
142                if !graph.neighbors[ni].contains(&(i as u32)) {
143                    graph.neighbors[ni].push(i as u32);
144                    if graph.neighbors[ni].len() > graph.config.max_degree {
145                        let nbs = graph.neighbors[ni].clone();
146                        graph.neighbors[ni] = graph.robust_prune(nb, &nbs);
147                    }
148                }
149            }
150        }
151        Ok(graph)
152    }
153
154    /// Greedy beam search returning top_k (node_id, distance) pairs.
155    pub fn search(&self, query: &[f32], top_k: usize) -> Vec<(u32, f32)> {
156        if self.vectors.is_empty() {
157            return vec![];
158        }
159        let beam = self.config.search_list_size.max(top_k);
160        let (ids, dists) = self.greedy_search_internal(query, beam);
161        ids.into_iter().zip(dists).take(top_k).collect()
162    }
163
164    fn greedy_search_internal(&self, query: &[f32], list_size: usize) -> (Vec<u32>, Vec<f32>) {
165        let mut visited = HashSet::new();
166        let mut frontier: BinaryHeap<Reverse<OrdF32Pair>> = BinaryHeap::new();
167        let mut results: Vec<(f32, u32)> = Vec::new();
168        let start = self.medoid;
169        let d = l2_sq(&self.vectors[start as usize], query);
170        frontier.push(Reverse(OrdF32Pair(d, start)));
171        visited.insert(start);
172        results.push((d, start));
173        while let Some(Reverse(OrdF32Pair(_, node))) = frontier.pop() {
174            for &nb in &self.neighbors[node as usize] {
175                if visited.insert(nb) {
176                    let dist = l2_sq(&self.vectors[nb as usize], query);
177                    results.push((dist, nb));
178                    frontier.push(Reverse(OrdF32Pair(dist, nb)));
179                }
180            }
181            if results.len() > list_size * 2 {
182                results.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
183                results.truncate(list_size);
184            }
185        }
186        results.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
187        results.truncate(list_size);
188        (
189            results.iter().map(|r| r.1).collect(),
190            results.iter().map(|r| r.0).collect(),
191        )
192    }
193
194    /// Robust prune: greedily select diverse neighbors via the alpha-RNG rule.
195    fn robust_prune(&self, node_id: u32, candidates: &[u32]) -> Vec<u32> {
196        let nv = &self.vectors[node_id as usize];
197        let mut scored: Vec<(f32, u32)> = candidates
198            .iter()
199            .filter(|&&c| c != node_id)
200            .map(|&c| (l2_sq(nv, &self.vectors[c as usize]), c))
201            .collect();
202        scored.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
203        let mut sel: Vec<u32> = Vec::new();
204        for (d2n, cand) in scored {
205            if sel.len() >= self.config.max_degree {
206                break;
207            }
208            let cv = &self.vectors[cand as usize];
209            if sel
210                .iter()
211                .all(|&s| d2n <= self.config.alpha * l2_sq(&self.vectors[s as usize], cv))
212            {
213                sel.push(cand);
214            }
215        }
216        sel
217    }
218}
219
220/// A node stored in SSD-backed layout: id + neighbors + vector in one page.
221#[derive(Debug, Clone, Serialize, Deserialize)]
222pub struct DiskNode {
223    pub node_id: u32,
224    pub neighbors: Vec<u32>,
225    pub vector: Vec<f32>,
226}
227
/// Counters describing the IO performed by one disk-based search.
#[derive(Clone, Debug, Default)]
pub struct IOStats {
    /// Number of page reads that missed the cache.
    pub pages_read: usize,
    /// Total bytes accounted for those page reads.
    pub bytes_read: usize,
    /// Number of node reads served from the page cache.
    pub cache_hits: usize,
}
235
236/// Simulated SSD-backed index with page-aligned reads and LRU cache.
237#[derive(Debug)]
238pub struct DiskIndex {
239    nodes: Vec<DiskNode>,
240    page_size: usize,
241    medoid: u32,
242    cache: PageCache,
243}
244
245impl DiskIndex {
246    /// Create from a built VamanaGraph.
247    pub fn from_graph(graph: &VamanaGraph, cache_size_pages: usize) -> Self {
248        let nodes = (0..graph.vectors.len())
249            .map(|i| DiskNode {
250                node_id: i as u32,
251                neighbors: graph.neighbors[i].clone(),
252                vector: graph.vectors[i].clone(),
253            })
254            .collect();
255        Self {
256            nodes,
257            page_size: graph.config.ssd_page_size,
258            medoid: graph.medoid,
259            cache: PageCache::new(cache_size_pages),
260        }
261    }
262
263    /// Beam search with IO accounting.
264    pub fn search_disk(
265        &mut self,
266        query: &[f32],
267        top_k: usize,
268        beam_width: usize,
269    ) -> (Vec<(u32, f32)>, IOStats) {
270        let mut stats = IOStats::default();
271        if self.nodes.is_empty() {
272            return (vec![], stats);
273        }
274        let mut visited = HashSet::new();
275        let mut frontier: BinaryHeap<Reverse<OrdF32Pair>> = BinaryHeap::new();
276        let mut results: Vec<(f32, u32)> = Vec::new();
277        let start = self.medoid;
278        let d = l2_sq(&self.read_node(start, &mut stats).vector.clone(), query);
279        frontier.push(Reverse(OrdF32Pair(d, start)));
280        visited.insert(start);
281        results.push((d, start));
282        while let Some(Reverse(OrdF32Pair(_, cur))) = frontier.pop() {
283            let nbs = self.read_node(cur, &mut stats).neighbors.clone();
284            for nb in nbs {
285                if visited.insert(nb) {
286                    let v = self.read_node(nb, &mut stats).vector.clone();
287                    let dist = l2_sq(&v, query);
288                    results.push((dist, nb));
289                    frontier.push(Reverse(OrdF32Pair(dist, nb)));
290                }
291            }
292            if results.len() > beam_width * 2 {
293                results.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
294                results.truncate(beam_width);
295            }
296        }
297        results.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
298        results.truncate(top_k);
299        (results.iter().map(|r| (r.1, r.0)).collect(), stats)
300    }
301
302    fn read_node(&mut self, node_id: u32, stats: &mut IOStats) -> &DiskNode {
303        let page_id = node_id as usize;
304        if self.cache.get(page_id) {
305            stats.cache_hits += 1;
306        } else {
307            stats.pages_read += 1;
308            stats.bytes_read += self.page_size;
309            self.cache.insert(page_id);
310        }
311        &self.nodes[node_id as usize]
312    }
313
314    /// Filtered search: predicates evaluated during traversal (not post-filter).
315    /// Ineligible nodes still expand the frontier to preserve graph connectivity.
316    pub fn search_with_filter<F>(
317        &mut self,
318        query: &[f32],
319        filter_fn: F,
320        top_k: usize,
321    ) -> Vec<(u32, f32)>
322    where
323        F: Fn(u32) -> bool,
324    {
325        if self.nodes.is_empty() {
326            return vec![];
327        }
328        let mut visited = HashSet::new();
329        let mut frontier: BinaryHeap<Reverse<OrdF32Pair>> = BinaryHeap::new();
330        let mut results: Vec<(f32, u32)> = Vec::new();
331        let mut io = IOStats::default();
332        let start = self.medoid;
333        let d = l2_sq(&self.read_node(start, &mut io).vector.clone(), query);
334        frontier.push(Reverse(OrdF32Pair(d, start)));
335        visited.insert(start);
336        if filter_fn(start) {
337            results.push((d, start));
338        }
339        while let Some(Reverse(OrdF32Pair(_, cur))) = frontier.pop() {
340            let nbs = self.read_node(cur, &mut io).neighbors.clone();
341            for nb in nbs {
342                if visited.insert(nb) {
343                    let v = self.read_node(nb, &mut io).vector.clone();
344                    let dist = l2_sq(&v, query);
345                    frontier.push(Reverse(OrdF32Pair(dist, nb)));
346                    if filter_fn(nb) {
347                        results.push((dist, nb));
348                    }
349                }
350            }
351        }
352        results.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
353        results.truncate(top_k);
354        results.iter().map(|r| (r.1, r.0)).collect()
355    }
356}
357
/// LRU page cache tracking access recency via a clock counter.
///
/// Only page ids are tracked (no payload): each cached id maps to the clock value
/// of its most recent access, and the entry with the smallest value is evicted.
#[derive(Debug)]
pub struct PageCache {
    /// Maximum number of cached pages; 0 disables caching.
    capacity: usize,
    /// Monotonic counter bumped on every lookup and insert.
    clock: u64,
    /// Cached page id -> clock value of its last access.
    entries: HashMap<usize, u64>,
    /// Number of lookups that hit.
    total_hits: u64,
    /// Number of lookups overall.
    total_accesses: u64,
}

impl PageCache {
    /// Create an empty cache holding at most `capacity` pages.
    pub fn new(capacity: usize) -> Self {
        Self {
            capacity,
            clock: 0,
            entries: HashMap::new(),
            total_hits: 0,
            total_accesses: 0,
        }
    }

    /// Returns true on cache hit, updating recency.
    ///
    /// Every call counts as one access for `cache_hit_rate`.
    pub fn get(&mut self, page_id: usize) -> bool {
        self.total_accesses += 1;
        self.clock += 1;
        match self.entries.get_mut(&page_id) {
            Some(ts) => {
                *ts = self.clock;
                self.total_hits += 1;
                true
            }
            None => false,
        }
    }

    /// Insert a page, evicting the least recently used one if at capacity.
    ///
    /// Re-inserting a page that is already cached only refreshes its recency; it
    /// never evicts another page. With capacity 0 this is a no-op.
    pub fn insert(&mut self, page_id: usize) {
        if self.capacity == 0 {
            return;
        }
        self.clock += 1;
        let now = self.clock;
        if let Some(ts) = self.entries.get_mut(&page_id) {
            *ts = now;
            return;
        }
        if self.entries.len() >= self.capacity {
            let victim = self
                .entries
                .iter()
                .min_by_key(|&(_, &ts)| ts)
                .map(|(&id, _)| id);
            if let Some(id) = victim {
                self.entries.remove(&id);
            }
        }
        self.entries.insert(page_id, now);
    }

    /// Cache hit rate in [0.0, 1.0]; 0.0 before any lookup.
    pub fn cache_hit_rate(&self) -> f64 {
        if self.total_accesses == 0 {
            0.0
        } else {
            self.total_hits as f64 / self.total_accesses as f64
        }
    }
}
420
421/// Finds the geometric medoid (point minimising sum of distances to all others).
422pub struct MedoidFinder;
423
424impl MedoidFinder {
425    pub fn find_medoid(vectors: &[Vec<f32>]) -> u32 {
426        if vectors.is_empty() {
427            return 0;
428        }
429        let (mut best_idx, mut best_sum) = (0u32, f32::MAX);
430        for i in 0..vectors.len() {
431            let sum: f32 = (0..vectors.len())
432                .map(|j| l2_sq(&vectors[i], &vectors[j]))
433                .sum();
434            if sum < best_sum {
435                best_sum = sum;
436                best_idx = i as u32;
437            }
438        }
439        best_idx
440    }
441}
442
/// Squared Euclidean (L2) distance between `a` and `b`.
///
/// Components are paired with `zip`, so if the slices differ in length the extra
/// components of the longer one are ignored.
fn l2_sq(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b.iter())
        .map(|(&x, &y)| {
            let diff = x - y;
            diff * diff
        })
        .sum()
}
447
/// `(distance, node_id)` pair with a total order, usable as a `BinaryHeap` key.
///
/// Pairs order by distance, then by node id. NaN distances compare equal to each
/// other and greater than every number, so the ordering is total and a NaN entry
/// sinks to the far end instead of corrupting the heap. Equality is defined via
/// the same ordering so that `Eq` and `Ord` agree.
#[derive(Debug, Clone)]
struct OrdF32Pair(f32, u32);

impl PartialEq for OrdF32Pair {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for OrdF32Pair {}

impl PartialOrd for OrdF32Pair {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for OrdF32Pair {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.0
            .partial_cmp(&other.0)
            // `partial_cmp` is None only when a NaN is involved: rank NaN last.
            .unwrap_or_else(|| self.0.is_nan().cmp(&other.0.is_nan()))
            .then(self.1.cmp(&other.1))
    }
}
464
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic fixture: row `i` holds the consecutive values `i*dim .. i*dim + dim`.
    fn make_vecs(n: usize, dim: usize) -> Vec<Vec<f32>> {
        let mut rows = Vec::with_capacity(n);
        for row in 0..n {
            let base = row * dim;
            rows.push((base..base + dim).map(|value| value as f32).collect());
        }
        rows
    }

    /// Default config with the degree bound `r` and search list size `l` overridden.
    fn default_cfg(r: usize, l: usize) -> VamanaConfig {
        VamanaConfig {
            search_list_size: l,
            max_degree: r,
            ..VamanaConfig::default()
        }
    }

    /// Building keeps every vector and respects the degree bound.
    #[test]
    fn build_graph_basic() {
        let graph = VamanaGraph::build(make_vecs(10, 4), default_cfg(4, 8)).unwrap();
        assert_eq!(graph.vectors.len(), 10);
        for adjacency in graph.neighbors.iter() {
            assert!(adjacency.len() <= 4);
        }
    }

    /// A point placed next to the query shows up among the top three hits.
    #[test]
    fn search_accuracy() {
        let mut points = make_vecs(20, 4);
        points.push(vec![0.1; 4]);
        let graph = VamanaGraph::build(points, default_cfg(8, 30)).unwrap();
        let hits = graph.search(&[0.0; 4], 3);
        let found_planted = hits.iter().any(|&(id, _)| id == 20);
        assert!(found_planted);
    }

    /// No adjacency list exceeds `max_degree` after pruning.
    #[test]
    fn robust_pruning_limits_degree() {
        let graph = VamanaGraph::build(make_vecs(50, 4), default_cfg(5, 16)).unwrap();
        for adjacency in graph.neighbors.iter() {
            assert!(adjacency.len() <= 5);
        }
    }

    /// The disk layout mirrors the in-memory graph node for node.
    #[test]
    fn disk_layout_roundtrip() {
        let points = make_vecs(10, 4);
        let graph = VamanaGraph::build(points.clone(), VamanaConfig::default()).unwrap();
        let disk = DiskIndex::from_graph(&graph, 16);
        for (i, expected) in points.iter().enumerate() {
            let node = &disk.nodes[i];
            assert_eq!(node.node_id, i as u32);
            assert_eq!(&node.vector, expected);
            assert_eq!(node.neighbors, graph.neighbors[i]);
        }
    }

    /// Lookups miss before insertion, hit after it, and the oldest page is evicted.
    #[test]
    fn page_cache_hits_and_misses() {
        let mut cache = PageCache::new(2);
        assert!(!cache.get(0));
        cache.insert(0);
        assert!(cache.get(0));
        cache.insert(1);
        // Third distinct page in a two-page cache: page 0 is the least recent.
        cache.insert(2);
        assert!(!cache.get(0));
        assert!(cache.get(1));
    }

    /// Two hits out of three lookups give a hit rate of 2/3.
    #[test]
    fn cache_hit_rate() {
        let mut cache = PageCache::new(4);
        for page in 0..2 {
            cache.insert(page);
        }
        assert!(cache.get(0));
        assert!(cache.get(1));
        assert!(!cache.get(2));
        let expected = 2.0 / 3.0;
        assert!((cache.cache_hit_rate() - expected).abs() < 1e-6);
    }

    /// Every hit of a filtered search satisfies the predicate.
    #[test]
    fn filtered_search() {
        let mut points = make_vecs(15, 4);
        points.push(vec![0.1; 4]);
        let graph = VamanaGraph::build(points, default_cfg(8, 20)).unwrap();
        let mut disk = DiskIndex::from_graph(&graph, 32);
        let hits = disk.search_with_filter(&[0.0; 4], |id| id % 2 == 0, 5);
        for (id, _) in hits.iter() {
            assert_eq!(id % 2, 0);
        }
    }

    /// The centre of a unit triangle plus centroid is chosen as medoid.
    #[test]
    fn medoid_selection() {
        let points = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![0.5, 0.5],
        ];
        let medoid = MedoidFinder::find_medoid(&points);
        assert_eq!(medoid, 3);
    }

    /// An empty dataset builds and searches without results.
    #[test]
    fn empty_dataset() {
        let graph = VamanaGraph::build(vec![], VamanaConfig::default()).unwrap();
        assert!(graph.vectors.is_empty());
        let hits = graph.search(&[1.0, 2.0], 5);
        assert!(hits.is_empty());
    }

    /// A single vector has no neighbors and is its own nearest hit.
    #[test]
    fn single_vector() {
        let only = vec![1.0, 2.0, 3.0];
        let graph = VamanaGraph::build(vec![only], VamanaConfig::default()).unwrap();
        assert!(graph.neighbors[0].is_empty());
        let hits = graph.search(&[1.0, 2.0, 3.0], 1);
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].0, 0);
    }

    /// Bytes read always equal pages read times the page size.
    #[test]
    fn io_stats_tracking() {
        let graph = VamanaGraph::build(make_vecs(10, 4), default_cfg(4, 10)).unwrap();
        let mut disk = DiskIndex::from_graph(&graph, 2);
        let (_, stats) = disk.search_disk(&[0.0; 4], 3, 10);
        assert!(stats.pages_read > 0);
        assert_eq!(stats.bytes_read, stats.pages_read * 4096);
    }

    /// Disk search returns exactly `top_k` hits in ascending distance order.
    #[test]
    fn disk_search_sorted_results() {
        let graph = VamanaGraph::build(make_vecs(20, 4), default_cfg(8, 20)).unwrap();
        let mut disk = DiskIndex::from_graph(&graph, 32);
        let (hits, stats) = disk.search_disk(&[0.0; 4], 5, 20);
        assert_eq!(hits.len(), 5);
        for pair in hits.windows(2) {
            assert!(pair[0].1 <= pair[1].1);
        }
        assert!(stats.pages_read + stats.cache_hits > 0);
    }

    /// Zero degree and alpha below one are rejected; the defaults are accepted.
    #[test]
    fn config_validation() {
        let zero_degree = VamanaConfig {
            max_degree: 0,
            ..Default::default()
        };
        assert!(zero_degree.validate().is_err());

        let small_alpha = VamanaConfig {
            alpha: 0.5,
            ..Default::default()
        };
        assert!(small_alpha.validate().is_err());

        assert!(VamanaConfig::default().validate().is_ok());
    }
}