Skip to main content

ruvector_diskann/
graph.rs

//! Vamana graph construction with α-robust pruning
//!
//! Optimized with:
//! - FlatVectors (contiguous memory, cache-friendly)
//! - VisitedSet (O(1) clear via generation counter)
//! - Rayon-parallel medoid finding
7
8use crate::distance::{l2_squared, FlatVectors, VisitedSet};
9use crate::error::{DiskAnnError, Result};
10use rayon::prelude::*;
11use std::collections::BinaryHeap;
12use std::cmp::Ordering;
13
/// Search-frontier entry: a node id paired with its distance to the query.
///
/// The ordering is reversed (a smaller distance compares as greater) so that
/// `BinaryHeap<Candidate>`, which is a max-heap, pops the closest node first.
/// Equality and ordering look at `distance` only; `id` is payload.
///
/// NaN distances are ordered as "farther than any number" and equal to each
/// other, so they surface last and never break the heap's ordering invariants.
#[derive(Clone)]
struct Candidate {
    id: u32,
    distance: f32,
}

impl PartialEq for Candidate {
    // Defined through `cmp` so that `Eq` stays reflexive for NaN and always
    // agrees with `Ord`. For ordinary numbers this is the same as `==`.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Candidate {}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> Ordering {
        // Operands are swapped on purpose: this is the reversed (min-heap) order.
        // `partial_cmp` only fails when a NaN is involved; in that case fall back
        // to comparing the "is NaN" flags, which ranks NaN after every number.
        other
            .distance
            .partial_cmp(&self.distance)
            .unwrap_or_else(|| other.distance.is_nan().cmp(&self.distance.is_nan()))
    }
}
32
/// Result-set entry: a node id paired with its distance to the query.
///
/// Ordered by ascending distance, so `BinaryHeap<MaxCandidate>` keeps the
/// farthest entry on top, ready to be evicted when the result set is full.
/// Equality and ordering look at `distance` only; `id` is payload.
///
/// NaN distances are ordered as "farther than any number" and equal to each
/// other, so they are the first entries to be evicted.
struct MaxCandidate {
    id: u32,
    distance: f32,
}

impl PartialEq for MaxCandidate {
    // Defined through `cmp` so that `Eq` stays reflexive for NaN and always
    // agrees with `Ord`. For ordinary numbers this is the same as `==`.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for MaxCandidate {}

impl PartialOrd for MaxCandidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for MaxCandidate {
    fn cmp(&self, other: &Self) -> Ordering {
        // `partial_cmp` only fails when a NaN is involved; in that case fall back
        // to comparing the "is NaN" flags, which ranks NaN after every number.
        self.distance
            .partial_cmp(&other.distance)
            .unwrap_or_else(|| self.distance.is_nan().cmp(&other.distance.is_nan()))
    }
}
49
/// Vamana graph with bounded out-degree
#[derive(Debug, Clone)]
pub struct VamanaGraph {
    /// Out-neighbor lists, indexed by node id.
    pub neighbors: Vec<Vec<u32>>,
    /// Node id every greedy search starts from. `build` sets it to the
    /// vector closest to the centroid of the data set.
    pub medoid: u32,
    /// Upper bound on the length of each neighbor list produced by pruning.
    pub max_degree: usize,
    /// Beam width of the searches run while building the graph.
    pub build_beam: usize,
    /// Pruning slack factor. Values above `1.0` trigger a second build pass
    /// that prunes with this factor; the first pass always uses `1.0`.
    pub alpha: f32,
}
58
59impl VamanaGraph {
60    pub fn new(n: usize, max_degree: usize, build_beam: usize, alpha: f32) -> Self {
61        Self {
62            neighbors: vec![Vec::new(); n],
63            medoid: 0,
64            max_degree,
65            build_beam,
66            alpha,
67        }
68    }
69
70    /// Build the Vamana graph over flat vector storage
71    pub fn build(&mut self, vectors: &FlatVectors) -> Result<()> {
72        let n = vectors.len();
73        if n == 0 {
74            return Err(DiskAnnError::Empty);
75        }
76
77        self.medoid = self.find_medoid_parallel(vectors);
78        self.init_random_graph(n);
79
80        let passes = if self.alpha > 1.0 { 2 } else { 1 };
81        for pass in 0..passes {
82            let alpha = if pass == 0 { 1.0 } else { self.alpha };
83
84            let mut order: Vec<u32> = (0..n as u32).collect();
85            {
86                use rand::prelude::*;
87                order.shuffle(&mut rand::thread_rng());
88            }
89
90            // Reusable visited set (O(1) clear per search)
91            let mut visited = VisitedSet::new(n);
92
93            for &node in &order {
94                let (candidates, _) =
95                    self.greedy_search_fast(vectors, vectors.get(node as usize), self.build_beam, &mut visited);
96
97                let pruned = self.robust_prune(vectors, node, &candidates, alpha);
98                self.neighbors[node as usize] = pruned.clone();
99
100                for &neighbor in &pruned {
101                    let nid = neighbor as usize;
102                    if !self.neighbors[nid].contains(&node) {
103                        if self.neighbors[nid].len() < self.max_degree {
104                            self.neighbors[nid].push(node);
105                        } else {
106                            let mut combined: Vec<u32> = self.neighbors[nid].clone();
107                            combined.push(node);
108                            let repruned = self.robust_prune(vectors, neighbor, &combined, alpha);
109                            self.neighbors[nid] = repruned;
110                        }
111                    }
112                }
113            }
114        }
115
116        Ok(())
117    }
118
119    /// Greedy beam search with reusable VisitedSet (zero-alloc per query)
120    pub fn greedy_search_fast(
121        &self,
122        vectors: &FlatVectors,
123        query: &[f32],
124        beam_width: usize,
125        visited: &mut VisitedSet,
126    ) -> (Vec<u32>, usize) {
127        visited.clear();
128
129        let mut candidates = BinaryHeap::<Candidate>::new();
130        let mut best = BinaryHeap::<MaxCandidate>::new();
131
132        let start = self.medoid;
133        let start_dist = l2_squared(vectors.get(start as usize), query);
134        candidates.push(Candidate { id: start, distance: start_dist });
135        best.push(MaxCandidate { id: start, distance: start_dist });
136        visited.insert(start);
137
138        let mut visit_count = 1usize;
139
140        while let Some(current) = candidates.pop() {
141            if best.len() >= beam_width {
142                if let Some(worst) = best.peek() {
143                    if current.distance > worst.distance {
144                        break;
145                    }
146                }
147            }
148
149            for &neighbor in &self.neighbors[current.id as usize] {
150                if visited.contains(neighbor) {
151                    continue;
152                }
153                visited.insert(neighbor);
154                visit_count += 1;
155
156                let dist = l2_squared(vectors.get(neighbor as usize), query);
157
158                let dominated = best.len() >= beam_width
159                    && best.peek().map_or(false, |w| dist >= w.distance);
160
161                if !dominated {
162                    candidates.push(Candidate { id: neighbor, distance: dist });
163                    best.push(MaxCandidate { id: neighbor, distance: dist });
164                    if best.len() > beam_width {
165                        best.pop();
166                    }
167                }
168            }
169        }
170
171        let mut result: Vec<(u32, f32)> = best.into_iter().map(|c| (c.id, c.distance)).collect();
172        result.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
173        let ids: Vec<u32> = result.into_iter().map(|(id, _)| id).collect();
174
175        (ids, visit_count)
176    }
177
178    /// Public search entry point (allocates its own VisitedSet)
179    pub fn greedy_search(
180        &self,
181        vectors: &FlatVectors,
182        query: &[f32],
183        beam_width: usize,
184    ) -> (Vec<u32>, usize) {
185        let mut visited = VisitedSet::new(vectors.len());
186        self.greedy_search_fast(vectors, query, beam_width, &mut visited)
187    }
188
189    fn robust_prune(
190        &self,
191        vectors: &FlatVectors,
192        node: u32,
193        candidates: &[u32],
194        alpha: f32,
195    ) -> Vec<u32> {
196        if candidates.is_empty() {
197            return Vec::new();
198        }
199
200        let node_vec = vectors.get(node as usize);
201        let mut sorted: Vec<(u32, f32)> = candidates
202            .iter()
203            .filter(|&&c| c != node)
204            .map(|&c| (c, l2_squared(vectors.get(c as usize), node_vec)))
205            .collect();
206        sorted.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
207
208        let mut result = Vec::with_capacity(self.max_degree);
209        for (cand_id, cand_dist) in &sorted {
210            if result.len() >= self.max_degree {
211                break;
212            }
213            let dominated = result.iter().any(|&selected: &u32| {
214                let inter_dist = l2_squared(vectors.get(selected as usize), vectors.get(*cand_id as usize));
215                alpha * inter_dist <= *cand_dist
216            });
217            if !dominated {
218                result.push(*cand_id);
219            }
220        }
221        result
222    }
223
224    /// Parallel medoid finding using rayon
225    fn find_medoid_parallel(&self, vectors: &FlatVectors) -> u32 {
226        let n = vectors.len();
227        let dim = vectors.dim;
228
229        // Compute centroid in parallel
230        let centroid: Vec<f32> = (0..dim)
231            .into_par_iter()
232            .map(|d| {
233                let mut sum = 0.0f32;
234                for i in 0..n {
235                    sum += vectors.get(i)[d];
236                }
237                sum / n as f32
238            })
239            .collect();
240
241        // Find closest point to centroid in parallel
242        (0..n as u32)
243            .into_par_iter()
244            .map(|i| (i, l2_squared(vectors.get(i as usize), &centroid)))
245            .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal))
246            .map(|(id, _)| id)
247            .unwrap_or(0)
248    }
249
250    fn init_random_graph(&mut self, n: usize) {
251        use rand::prelude::*;
252        let mut rng = rand::thread_rng();
253        let degree = self.max_degree.min(n - 1);
254
255        for i in 0..n {
256            let mut neighbors = Vec::with_capacity(degree);
257            let mut attempts = 0;
258            while neighbors.len() < degree && attempts < degree * 3 {
259                let j = rng.gen_range(0..n) as u32;
260                if j != i as u32 && !neighbors.contains(&j) {
261                    neighbors.push(j);
262                }
263                attempts += 1;
264            }
265            self.neighbors[i] = neighbors;
266        }
267    }
268}
269
#[cfg(test)]
mod tests {
    use super::*;

    /// Produces `n` vectors of dimension `dim` with randomly generated components.
    fn random_flat(n: usize, dim: usize) -> FlatVectors {
        use rand::prelude::*;

        let mut rng = rand::thread_rng();
        let mut flat = FlatVectors::with_capacity(dim, n);
        for _ in 0..n {
            let row: Vec<f32> = std::iter::repeat_with(|| rng.gen::<f32>())
                .take(dim)
                .collect();
            flat.push(&row);
        }
        flat
    }

    /// Searching for a stored vector must return that vector's own id.
    #[test]
    fn test_vamana_build_and_search() {
        const NODE_COUNT: usize = 200;
        const QUERY_ID: u32 = 42;

        let vectors = random_flat(NODE_COUNT, 32);
        let mut graph = VamanaGraph::new(NODE_COUNT, 32, 64, 1.2);
        graph.build(&vectors).unwrap();

        let query = vectors.get(QUERY_ID as usize);
        let (hits, _visited) = graph.greedy_search(&vectors, query, 10);
        assert!(!hits.is_empty());
        assert!(hits.contains(&QUERY_ID));
    }

    /// No neighbor list may exceed the configured maximum degree after a build.
    #[test]
    fn test_vamana_bounded_degree() {
        const MAX_DEGREE: usize = 8;

        let vectors = random_flat(100, 16);
        let mut graph = VamanaGraph::new(100, MAX_DEGREE, 32, 1.2);
        graph.build(&vectors).unwrap();

        for adjacency in graph.neighbors.iter() {
            assert!(adjacency.len() <= MAX_DEGREE);
        }
    }
}
306}