// ruvector_core/advanced_features/diskann.rs
//! DiskANN / Vamana SSD-Backed Approximate Nearest Neighbor Index
//!
//! Implements the Vamana graph index from the DiskANN paper (Subramanya et al., 2019).
//! Each node connects to R neighbors chosen via **alpha-RNG pruning** -- a relaxed
//! Relative Neighborhood Graph balancing proximity and angular diversity.
//!
//! # Why DiskANN achieves 95%+ recall at sub-10ms
//!
//! - **Vamana graph**: alpha > 1.0 retains long-range shortcuts for O(log n) hops.
//! - **SSD layout**: node vector + neighbors packed in aligned pages; one read per hop.
//! - **Page cache**: LRU cache keeps hot pages in memory (80-95% hit rates typical).
//! - **Filtered traversal**: predicates evaluated during search, not post-filter.
//!
//! # Alpha-RNG Pruning
//!
//! A candidate c is kept only if for every already-selected neighbor n,
//! `dist(p, c) <= alpha * dist(n, c)`, ensuring angular diversity.

use crate::error::{Result, RuvectorError};
use serde::{Deserialize, Serialize};
use std::cmp::Reverse;
use std::collections::{BinaryHeap, HashMap, HashSet};

/// Configuration for the Vamana graph index.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VamanaConfig {
    /// Maximum out-degree per node (R in the paper). Typical: 32-64.
    pub max_degree: usize,
    /// Search list size (L). Larger = better recall, slower search.
    pub search_list_size: usize,
    /// Alpha-RNG pruning parameter (>= 1.0). Typical: 1.2; larger values keep
    /// more long-range edges at the cost of degree budget.
    pub alpha: f32,
    /// Thread count for build (reserved for future parallel builds).
    pub num_build_threads: usize,
    /// Page size for SSD-aligned layout in bytes.
    pub ssd_page_size: usize,
}

39impl Default for VamanaConfig {
40    fn default() -> Self {
41        Self { max_degree: 32, search_list_size: 64, alpha: 1.2, num_build_threads: 1, ssd_page_size: 4096 }
42    }
43}
44
45impl VamanaConfig {
46    /// Validate configuration parameters.
47    pub fn validate(&self) -> Result<()> {
48        if self.max_degree == 0 {
49            return Err(RuvectorError::InvalidParameter("max_degree must be > 0".into()));
50        }
51        if self.search_list_size < 1 {
52            return Err(RuvectorError::InvalidParameter("search_list_size must be >= 1".into()));
53        }
54        if self.alpha < 1.0 {
55            return Err(RuvectorError::InvalidParameter("alpha must be >= 1.0".into()));
56        }
57        Ok(())
58    }
59}
60
/// In-memory Vamana graph for building and searching.
///
/// `neighbors[i]` holds the out-edges of node `i` and `vectors[i]` its
/// embedding; all searches start from the `medoid` entry point.
#[derive(Debug, Clone)]
pub struct VamanaGraph {
    /// Adjacency lists per node (out-degree bounded by `config.max_degree`).
    pub neighbors: Vec<Vec<u32>>,
    /// Vectors, row-major.
    pub vectors: Vec<Vec<f32>>,
    /// Medoid (entry point) index.
    pub medoid: u32,
    /// Build config.
    pub config: VamanaConfig,
}

74impl VamanaGraph {
    /// Build a Vamana graph: find medoid, init neighbors, then refine via greedy search + robust prune.
    ///
    /// Follows the Vamana construction: seed an arbitrary bounded-degree graph,
    /// then for each node run a greedy search from the medoid and replace its
    /// adjacency list with an alpha-RNG-pruned candidate set, adding reverse
    /// edges (and re-pruning any neighbor whose degree overflows).
    ///
    /// # Errors
    /// Fails if `config` is invalid or the input vectors differ in dimensionality.
    pub fn build(vectors: Vec<Vec<f32>>, config: VamanaConfig) -> Result<Self> {
        config.validate()?;
        let n = vectors.len();
        if n == 0 {
            // Empty index: medoid 0 is a placeholder; search() guards on emptiness.
            return Ok(Self { neighbors: vec![], vectors: vec![], medoid: 0, config });
        }
        let dim = vectors[0].len();
        for v in &vectors {
            if v.len() != dim {
                return Err(RuvectorError::DimensionMismatch { expected: dim, actual: v.len() });
            }
        }
        let medoid = MedoidFinder::find_medoid(&vectors);
        let mut graph = Self { neighbors: vec![vec![]; n], vectors, medoid, config };
        // Initialize with sequential neighbors: up to max_degree lowest-index
        // nodes (skipping self), giving the refinement pass edges to follow.
        for i in 0..n {
            let mut nb = Vec::new();
            for j in 0..n.min(graph.config.max_degree + 1) {
                if j != i { nb.push(j as u32); }
                if nb.len() >= graph.config.max_degree { break; }
            }
            graph.neighbors[i] = nb;
        }
        // Refine: search, prune, add reverse edges.
        for i in 0..n {
            // Clone the query vector so greedy_search_internal can borrow `graph`.
            let query = graph.vectors[i].clone();
            let (cands, _) = graph.greedy_search_internal(&query, graph.config.search_list_size);
            // Candidate pool = search results plus current neighbors, minus self.
            let mut cset: Vec<u32> = cands.into_iter().filter(|&c| c != i as u32).collect();
            for &nb in &graph.neighbors[i] {
                if !cset.contains(&nb) { cset.push(nb); }
            }
            let pruned = graph.robust_prune(i as u32, &cset);
            graph.neighbors[i] = pruned.clone();
            // Make edges bidirectional; re-prune neighbors that exceed max_degree.
            for &nb in &pruned {
                let ni = nb as usize;
                if !graph.neighbors[ni].contains(&(i as u32)) {
                    graph.neighbors[ni].push(i as u32);
                    if graph.neighbors[ni].len() > graph.config.max_degree {
                        let nbs = graph.neighbors[ni].clone();
                        graph.neighbors[ni] = graph.robust_prune(nb, &nbs);
                    }
                }
            }
        }
        Ok(graph)
    }

123    /// Greedy beam search returning top_k (node_id, distance) pairs.
124    pub fn search(&self, query: &[f32], top_k: usize) -> Vec<(u32, f32)> {
125        if self.vectors.is_empty() { return vec![]; }
126        let beam = self.config.search_list_size.max(top_k);
127        let (ids, dists) = self.greedy_search_internal(query, beam);
128        ids.into_iter().zip(dists).take(top_k).collect()
129    }
130
131    fn greedy_search_internal(&self, query: &[f32], list_size: usize) -> (Vec<u32>, Vec<f32>) {
132        let mut visited = HashSet::new();
133        let mut frontier: BinaryHeap<Reverse<OrdF32Pair>> = BinaryHeap::new();
134        let mut results: Vec<(f32, u32)> = Vec::new();
135        let start = self.medoid;
136        let d = l2_sq(&self.vectors[start as usize], query);
137        frontier.push(Reverse(OrdF32Pair(d, start)));
138        visited.insert(start);
139        results.push((d, start));
140        while let Some(Reverse(OrdF32Pair(_, node))) = frontier.pop() {
141            for &nb in &self.neighbors[node as usize] {
142                if visited.insert(nb) {
143                    let dist = l2_sq(&self.vectors[nb as usize], query);
144                    results.push((dist, nb));
145                    frontier.push(Reverse(OrdF32Pair(dist, nb)));
146                }
147            }
148            if results.len() > list_size * 2 {
149                results.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
150                results.truncate(list_size);
151            }
152        }
153        results.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
154        results.truncate(list_size);
155        (results.iter().map(|r| r.1).collect(), results.iter().map(|r| r.0).collect())
156    }
157
158    /// Robust prune: greedily select diverse neighbors via the alpha-RNG rule.
159    fn robust_prune(&self, node_id: u32, candidates: &[u32]) -> Vec<u32> {
160        let nv = &self.vectors[node_id as usize];
161        let mut scored: Vec<(f32, u32)> = candidates.iter()
162            .filter(|&&c| c != node_id)
163            .map(|&c| (l2_sq(nv, &self.vectors[c as usize]), c))
164            .collect();
165        scored.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
166        let mut sel: Vec<u32> = Vec::new();
167        for (d2n, cand) in scored {
168            if sel.len() >= self.config.max_degree { break; }
169            let cv = &self.vectors[cand as usize];
170            if sel.iter().all(|&s| d2n <= self.config.alpha * l2_sq(&self.vectors[s as usize], cv)) {
171                sel.push(cand);
172            }
173        }
174        sel
175    }
176}
177
/// A node stored in SSD-backed layout: id + neighbors + vector in one page,
/// so a single page read yields everything needed for one search hop.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DiskNode {
    /// Index of this node in the source graph.
    pub node_id: u32,
    /// Out-edge node ids.
    pub neighbors: Vec<u32>,
    /// The node's embedding.
    pub vector: Vec<f32>,
}

/// IO statistics for disk-based search.
#[derive(Debug, Clone, Default)]
pub struct IOStats {
    /// Number of (simulated) SSD page reads.
    pub pages_read: usize,
    /// Total bytes read: pages_read * page size.
    pub bytes_read: usize,
    /// Number of node reads served from the page cache.
    pub cache_hits: usize,
}

/// Simulated SSD-backed index with page-aligned reads and LRU cache.
#[derive(Debug)]
pub struct DiskIndex {
    // One DiskNode per "page"; the node id doubles as the page id.
    nodes: Vec<DiskNode>,
    // Bytes charged to IOStats per simulated page read.
    page_size: usize,
    // Search entry point, copied from the source graph.
    medoid: u32,
    // LRU cache of recently read pages.
    cache: PageCache,
}

203impl DiskIndex {
204    /// Create from a built VamanaGraph.
205    pub fn from_graph(graph: &VamanaGraph, cache_size_pages: usize) -> Self {
206        let nodes = (0..graph.vectors.len()).map(|i| DiskNode {
207            node_id: i as u32, neighbors: graph.neighbors[i].clone(), vector: graph.vectors[i].clone(),
208        }).collect();
209        Self { nodes, page_size: graph.config.ssd_page_size, medoid: graph.medoid, cache: PageCache::new(cache_size_pages) }
210    }
211
212    /// Beam search with IO accounting.
213    pub fn search_disk(&mut self, query: &[f32], top_k: usize, beam_width: usize) -> (Vec<(u32, f32)>, IOStats) {
214        let mut stats = IOStats::default();
215        if self.nodes.is_empty() { return (vec![], stats); }
216        let mut visited = HashSet::new();
217        let mut frontier: BinaryHeap<Reverse<OrdF32Pair>> = BinaryHeap::new();
218        let mut results: Vec<(f32, u32)> = Vec::new();
219        let start = self.medoid;
220        let d = l2_sq(&self.read_node(start, &mut stats).vector.clone(), query);
221        frontier.push(Reverse(OrdF32Pair(d, start)));
222        visited.insert(start);
223        results.push((d, start));
224        while let Some(Reverse(OrdF32Pair(_, cur))) = frontier.pop() {
225            let nbs = self.read_node(cur, &mut stats).neighbors.clone();
226            for nb in nbs {
227                if visited.insert(nb) {
228                    let v = self.read_node(nb, &mut stats).vector.clone();
229                    let dist = l2_sq(&v, query);
230                    results.push((dist, nb));
231                    frontier.push(Reverse(OrdF32Pair(dist, nb)));
232                }
233            }
234            if results.len() > beam_width * 2 {
235                results.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
236                results.truncate(beam_width);
237            }
238        }
239        results.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
240        results.truncate(top_k);
241        (results.iter().map(|r| (r.1, r.0)).collect(), stats)
242    }
243
244    fn read_node(&mut self, node_id: u32, stats: &mut IOStats) -> &DiskNode {
245        let page_id = node_id as usize;
246        if self.cache.get(page_id) { stats.cache_hits += 1; }
247        else { stats.pages_read += 1; stats.bytes_read += self.page_size; self.cache.insert(page_id); }
248        &self.nodes[node_id as usize]
249    }
250
251    /// Filtered search: predicates evaluated during traversal (not post-filter).
252    /// Ineligible nodes still expand the frontier to preserve graph connectivity.
253    pub fn search_with_filter<F>(&mut self, query: &[f32], filter_fn: F, top_k: usize) -> Vec<(u32, f32)>
254    where F: Fn(u32) -> bool {
255        if self.nodes.is_empty() { return vec![]; }
256        let mut visited = HashSet::new();
257        let mut frontier: BinaryHeap<Reverse<OrdF32Pair>> = BinaryHeap::new();
258        let mut results: Vec<(f32, u32)> = Vec::new();
259        let mut io = IOStats::default();
260        let start = self.medoid;
261        let d = l2_sq(&self.read_node(start, &mut io).vector.clone(), query);
262        frontier.push(Reverse(OrdF32Pair(d, start)));
263        visited.insert(start);
264        if filter_fn(start) { results.push((d, start)); }
265        while let Some(Reverse(OrdF32Pair(_, cur))) = frontier.pop() {
266            let nbs = self.read_node(cur, &mut io).neighbors.clone();
267            for nb in nbs {
268                if visited.insert(nb) {
269                    let v = self.read_node(nb, &mut io).vector.clone();
270                    let dist = l2_sq(&v, query);
271                    frontier.push(Reverse(OrdF32Pair(dist, nb)));
272                    if filter_fn(nb) { results.push((dist, nb)); }
273                }
274            }
275        }
276        results.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
277        results.truncate(top_k);
278        results.iter().map(|r| (r.1, r.0)).collect()
279    }
280}
281
/// LRU page cache tracking access recency via a monotonically increasing clock.
#[derive(Debug)]
pub struct PageCache {
    /// Maximum number of pages held; 0 disables caching entirely.
    capacity: usize,
    /// Logical clock bumped on every access; entry timestamps order by recency.
    clock: u64,
    /// page_id -> last-access timestamp.
    entries: HashMap<usize, u64>,
    /// Lifetime hit count (for cache_hit_rate).
    total_hits: u64,
    /// Lifetime access count (for cache_hit_rate).
    total_accesses: u64,
}

impl PageCache {
    /// Create a cache holding at most `capacity` pages.
    pub fn new(capacity: usize) -> Self {
        Self { capacity, clock: 0, entries: HashMap::new(), total_hits: 0, total_accesses: 0 }
    }

    /// Returns true on cache hit, updating the page's recency timestamp.
    pub fn get(&mut self, page_id: usize) -> bool {
        self.total_accesses += 1;
        self.clock += 1;
        if let Some(ts) = self.entries.get_mut(&page_id) {
            *ts = self.clock;
            self.total_hits += 1;
            true
        } else {
            false
        }
    }

    /// Insert a page, evicting the least-recently-used entry if at capacity.
    ///
    /// Re-inserting a page that is already cached only refreshes its recency.
    /// (Previously an eviction ran unconditionally at capacity, so a redundant
    /// insert of a cached page wrongly evicted another entry and shrank the
    /// cache below capacity.)
    pub fn insert(&mut self, page_id: usize) {
        if self.capacity == 0 { return; }
        // Evict only when adding a genuinely new page would exceed capacity.
        if !self.entries.contains_key(&page_id) && self.entries.len() >= self.capacity {
            let lru = self.entries.iter().min_by_key(|&(_, ts)| *ts).map(|(&k, _)| k);
            if let Some(k) = lru { self.entries.remove(&k); }
        }
        self.clock += 1;
        self.entries.insert(page_id, self.clock);
    }

    /// Cache hit rate in [0.0, 1.0]; 0.0 before any access.
    pub fn cache_hit_rate(&self) -> f64 {
        if self.total_accesses == 0 { 0.0 } else { self.total_hits as f64 / self.total_accesses as f64 }
    }
}

322
323/// Finds the geometric medoid (point minimising sum of distances to all others).
324pub struct MedoidFinder;
325
326impl MedoidFinder {
327    pub fn find_medoid(vectors: &[Vec<f32>]) -> u32 {
328        if vectors.is_empty() { return 0; }
329        let (mut best_idx, mut best_sum) = (0u32, f32::MAX);
330        for i in 0..vectors.len() {
331            let sum: f32 = (0..vectors.len()).map(|j| l2_sq(&vectors[i], &vectors[j])).sum();
332            if sum < best_sum { best_sum = sum; best_idx = i as u32; }
333        }
334        best_idx
335    }
336}
337
/// Squared Euclidean (L2) distance between two equal-length slices.
fn l2_sq(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b.iter())
        .map(|(x, y)| {
            let d = x - y;
            d * d
        })
        .sum()
}

/// (distance, node id) pair with a total order: ascending distance, ties
/// broken by id. NaN distances compare via the id tie-break so that heap
/// operations never panic.
#[derive(Debug, Clone, PartialEq)]
struct OrdF32Pair(f32, u32);

impl Eq for OrdF32Pair {}

impl PartialOrd for OrdF32Pair {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for OrdF32Pair {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        match self.0.partial_cmp(&other.0) {
            Some(ord) => ord.then(self.1.cmp(&other.1)),
            None => self.1.cmp(&other.1),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic fixture: vector i is [i*dim, i*dim+1, ..., i*dim+dim-1].
    fn make_vecs(n: usize, dim: usize) -> Vec<Vec<f32>> {
        (0..n).map(|i| (0..dim).map(|d| (i * dim + d) as f32).collect()).collect()
    }
    /// Default config with max_degree (R) and search_list_size (L) overridden.
    fn default_cfg(r: usize, l: usize) -> VamanaConfig {
        VamanaConfig { max_degree: r, search_list_size: l, ..Default::default() }
    }

    // Build succeeds and every adjacency list respects the R bound.
    #[test]
    fn build_graph_basic() {
        let g = VamanaGraph::build(make_vecs(10, 4), default_cfg(4, 8)).unwrap();
        assert_eq!(g.vectors.len(), 10);
        for nb in &g.neighbors { assert!(nb.len() <= 4); }
    }

    // The vector nearest the origin (index 20) must surface in the top-3.
    #[test]
    fn search_accuracy() {
        let mut v = make_vecs(20, 4);
        v.push(vec![0.1, 0.1, 0.1, 0.1]);
        let g = VamanaGraph::build(v, default_cfg(8, 30)).unwrap();
        let r = g.search(&[0.0; 4], 3);
        assert!(r.iter().any(|&(id, _)| id == 20));
    }

    // Alpha-RNG pruning keeps degree <= R on a larger dataset too.
    #[test]
    fn robust_pruning_limits_degree() {
        let g = VamanaGraph::build(make_vecs(50, 4), default_cfg(5, 16)).unwrap();
        for nb in &g.neighbors { assert!(nb.len() <= 5); }
    }

    // DiskIndex preserves node ids, vectors and adjacency verbatim.
    #[test]
    fn disk_layout_roundtrip() {
        let v = make_vecs(10, 4);
        let g = VamanaGraph::build(v.clone(), VamanaConfig::default()).unwrap();
        let d = DiskIndex::from_graph(&g, 16);
        for i in 0..10 {
            assert_eq!(d.nodes[i].node_id, i as u32);
            assert_eq!(d.nodes[i].vector, v[i]);
            assert_eq!(d.nodes[i].neighbors, g.neighbors[i]);
        }
    }

    // Miss -> insert -> hit; a capacity-2 cache evicts the LRU page.
    #[test]
    fn page_cache_hits_and_misses() {
        let mut c = PageCache::new(2);
        assert!(!c.get(0));
        c.insert(0);
        assert!(c.get(0));
        c.insert(1);
        c.insert(2); // evicts 0
        assert!(!c.get(0));
        assert!(c.get(1));
    }

    // 2 hits out of 3 accesses -> rate 2/3.
    #[test]
    fn cache_hit_rate() {
        let mut c = PageCache::new(4);
        c.insert(0); c.insert(1);
        assert!(c.get(0)); assert!(c.get(1)); assert!(!c.get(2));
        assert!((c.cache_hit_rate() - 2.0 / 3.0).abs() < 1e-6);
    }

    // Only even ids may appear when the predicate rejects odd ids.
    #[test]
    fn filtered_search() {
        let mut v = make_vecs(15, 4);
        v.push(vec![0.1; 4]);
        let g = VamanaGraph::build(v, default_cfg(8, 20)).unwrap();
        let mut d = DiskIndex::from_graph(&g, 32);
        let r = d.search_with_filter(&[0.0; 4], |id| id % 2 == 0, 5);
        for &(id, _) in &r { assert_eq!(id % 2, 0); }
    }

    // The centre point (0.5, 0.5) minimises total distance over this square.
    #[test]
    fn medoid_selection() {
        let v = vec![vec![0.0, 0.0], vec![1.0, 0.0], vec![0.0, 1.0], vec![0.5, 0.5]];
        assert_eq!(MedoidFinder::find_medoid(&v), 3);
    }

    // Building from nothing yields an empty but searchable index.
    #[test]
    fn empty_dataset() {
        let g = VamanaGraph::build(vec![], VamanaConfig::default()).unwrap();
        assert!(g.vectors.is_empty());
        assert!(g.search(&[1.0, 2.0], 5).is_empty());
    }

    // A singleton graph has no edges but still answers queries.
    #[test]
    fn single_vector() {
        let g = VamanaGraph::build(vec![vec![1.0, 2.0, 3.0]], VamanaConfig::default()).unwrap();
        assert!(g.neighbors[0].is_empty());
        let r = g.search(&[1.0, 2.0, 3.0], 1);
        assert_eq!(r.len(), 1);
        assert_eq!(r[0].0, 0);
    }

    // bytes_read must equal pages_read times the default 4096-byte page size.
    #[test]
    fn io_stats_tracking() {
        let g = VamanaGraph::build(make_vecs(10, 4), default_cfg(4, 10)).unwrap();
        let mut d = DiskIndex::from_graph(&g, 2);
        let (_, s) = d.search_disk(&[0.0; 4], 3, 10);
        assert!(s.pages_read > 0);
        assert_eq!(s.bytes_read, s.pages_read * 4096);
    }

    // Results come back ascending by distance, with IO activity recorded.
    #[test]
    fn disk_search_sorted_results() {
        let g = VamanaGraph::build(make_vecs(20, 4), default_cfg(8, 20)).unwrap();
        let mut d = DiskIndex::from_graph(&g, 32);
        let (r, s) = d.search_disk(&[0.0; 4], 5, 20);
        assert_eq!(r.len(), 5);
        for w in r.windows(2) { assert!(w[0].1 <= w[1].1); }
        assert!(s.pages_read + s.cache_hits > 0);
    }

    // Degenerate configs are rejected; the default passes validation.
    #[test]
    fn config_validation() {
        assert!(VamanaConfig { max_degree: 0, ..Default::default() }.validate().is_err());
        assert!(VamanaConfig { alpha: 0.5, ..Default::default() }.validate().is_err());
        assert!(VamanaConfig::default().validate().is_ok());
    }
}