// ruvector_diskann/graph.rs

//! Vamana graph construction with α-robust pruning.
//!
//! Optimized with:
//! - `FlatVectors` (contiguous memory, cache-friendly)
//! - `VisitedSet` (O(1) clear via generation counter)
//! - Rayon-parallel medoid finding
7
8use crate::distance::{l2_squared, FlatVectors, VisitedSet};
9use crate::error::{DiskAnnError, Result};
10use rayon::prelude::*;
11use std::cmp::Ordering;
12use std::collections::BinaryHeap;
13
/// Frontier entry for the beam search: a node id and its distance to the query.
///
/// The ordering is deliberately *reversed* on `distance` so that a
/// `BinaryHeap<Candidate>` (a max-heap) pops the closest candidate first.
/// Only `distance` takes part in comparisons; `id` is ignored.
#[derive(Debug, Clone)]
struct Candidate {
    id: u32,
    distance: f32,
}

impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so equality and ordering can never disagree,
        // which also makes equality reflexive for NaN distances.
        self.cmp(other) == Ordering::Equal
    }
}
impl Eq for Candidate {}
impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> Ordering {
        // `total_cmp` is a real total order over f32 (a positive NaN sorts above
        // every number), so the heap invariant holds even for NaN distances.
        // Operands are swapped to turn the max-heap into a min-heap on distance.
        other.distance.total_cmp(&self.distance)
    }
}
39
/// Beam entry for the search: a node id and its distance to the query.
///
/// Uses the natural ordering on `distance`, so the top of a
/// `BinaryHeap<MaxCandidate>` is the *farthest* retained node, i.e. the one to
/// evict when the beam overflows. Only `distance` takes part in comparisons.
#[derive(Debug, Clone)]
struct MaxCandidate {
    id: u32,
    distance: f32,
}
impl PartialEq for MaxCandidate {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so equality and ordering can never disagree,
        // which also makes equality reflexive for NaN distances.
        self.cmp(other) == Ordering::Equal
    }
}
impl Eq for MaxCandidate {}
impl PartialOrd for MaxCandidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for MaxCandidate {
    fn cmp(&self, other: &Self) -> Ordering {
        // `total_cmp` is a real total order over f32: a positive NaN compares
        // greater than every number, so it surfaces at the heap top and is the
        // first entry to be evicted.
        self.distance.total_cmp(&other.distance)
    }
}
62
/// Vamana graph with bounded out-degree.
///
/// Node ids are `u32` indices into the vector storage the graph was built over.
#[derive(Debug, Clone)]
pub struct VamanaGraph {
    /// Adjacency lists: `neighbors[i]` holds the out-neighbour ids of node `i`.
    pub neighbors: Vec<Vec<u32>>,
    /// Entry point for every search; `build` sets it to the stored point
    /// closest to the centroid of all vectors.
    pub medoid: u32,
    /// Upper bound on the length of each adjacency list.
    pub max_degree: usize,
    /// Beam width used for the searches performed while building.
    pub build_beam: usize,
    /// Pruning slack factor; `build` applies it in a second refinement pass,
    /// and only when it is greater than `1.0`.
    pub alpha: f32,
}
71
72impl VamanaGraph {
73    pub fn new(n: usize, max_degree: usize, build_beam: usize, alpha: f32) -> Self {
74        Self {
75            neighbors: vec![Vec::new(); n],
76            medoid: 0,
77            max_degree,
78            build_beam,
79            alpha,
80        }
81    }
82
83    /// Build the Vamana graph over flat vector storage
84    pub fn build(&mut self, vectors: &FlatVectors) -> Result<()> {
85        let n = vectors.len();
86        if n == 0 {
87            return Err(DiskAnnError::Empty);
88        }
89
90        self.medoid = self.find_medoid_parallel(vectors);
91        self.init_random_graph(n);
92
93        let passes = if self.alpha > 1.0 { 2 } else { 1 };
94        for pass in 0..passes {
95            let alpha = if pass == 0 { 1.0 } else { self.alpha };
96
97            let mut order: Vec<u32> = (0..n as u32).collect();
98            {
99                use rand::prelude::*;
100                order.shuffle(&mut rand::thread_rng());
101            }
102
103            // Reusable visited set (O(1) clear per search)
104            let mut visited = VisitedSet::new(n);
105
106            for &node in &order {
107                let (candidates, _) = self.greedy_search_fast(
108                    vectors,
109                    vectors.get(node as usize),
110                    self.build_beam,
111                    &mut visited,
112                );
113
114                let pruned = self.robust_prune(vectors, node, &candidates, alpha);
115                self.neighbors[node as usize] = pruned.clone();
116
117                for &neighbor in &pruned {
118                    let nid = neighbor as usize;
119                    if !self.neighbors[nid].contains(&node) {
120                        if self.neighbors[nid].len() < self.max_degree {
121                            self.neighbors[nid].push(node);
122                        } else {
123                            let mut combined: Vec<u32> = self.neighbors[nid].clone();
124                            combined.push(node);
125                            let repruned = self.robust_prune(vectors, neighbor, &combined, alpha);
126                            self.neighbors[nid] = repruned;
127                        }
128                    }
129                }
130            }
131        }
132
133        Ok(())
134    }
135
136    /// Greedy beam search with reusable VisitedSet (zero-alloc per query)
137    pub fn greedy_search_fast(
138        &self,
139        vectors: &FlatVectors,
140        query: &[f32],
141        beam_width: usize,
142        visited: &mut VisitedSet,
143    ) -> (Vec<u32>, usize) {
144        visited.clear();
145
146        let mut candidates = BinaryHeap::<Candidate>::new();
147        let mut best = BinaryHeap::<MaxCandidate>::new();
148
149        let start = self.medoid;
150        let start_dist = l2_squared(vectors.get(start as usize), query);
151        candidates.push(Candidate {
152            id: start,
153            distance: start_dist,
154        });
155        best.push(MaxCandidate {
156            id: start,
157            distance: start_dist,
158        });
159        visited.insert(start);
160
161        let mut visit_count = 1usize;
162
163        while let Some(current) = candidates.pop() {
164            if best.len() >= beam_width {
165                if let Some(worst) = best.peek() {
166                    if current.distance > worst.distance {
167                        break;
168                    }
169                }
170            }
171
172            for &neighbor in &self.neighbors[current.id as usize] {
173                if visited.contains(neighbor) {
174                    continue;
175                }
176                visited.insert(neighbor);
177                visit_count += 1;
178
179                let dist = l2_squared(vectors.get(neighbor as usize), query);
180
181                let dominated =
182                    best.len() >= beam_width && best.peek().map_or(false, |w| dist >= w.distance);
183
184                if !dominated {
185                    candidates.push(Candidate {
186                        id: neighbor,
187                        distance: dist,
188                    });
189                    best.push(MaxCandidate {
190                        id: neighbor,
191                        distance: dist,
192                    });
193                    if best.len() > beam_width {
194                        best.pop();
195                    }
196                }
197            }
198        }
199
200        let mut result: Vec<(u32, f32)> = best.into_iter().map(|c| (c.id, c.distance)).collect();
201        result.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
202        let ids: Vec<u32> = result.into_iter().map(|(id, _)| id).collect();
203
204        (ids, visit_count)
205    }
206
207    /// Public search entry point (allocates its own VisitedSet)
208    pub fn greedy_search(
209        &self,
210        vectors: &FlatVectors,
211        query: &[f32],
212        beam_width: usize,
213    ) -> (Vec<u32>, usize) {
214        let mut visited = VisitedSet::new(vectors.len());
215        self.greedy_search_fast(vectors, query, beam_width, &mut visited)
216    }
217
218    fn robust_prune(
219        &self,
220        vectors: &FlatVectors,
221        node: u32,
222        candidates: &[u32],
223        alpha: f32,
224    ) -> Vec<u32> {
225        if candidates.is_empty() {
226            return Vec::new();
227        }
228
229        let node_vec = vectors.get(node as usize);
230        let mut sorted: Vec<(u32, f32)> = candidates
231            .iter()
232            .filter(|&&c| c != node)
233            .map(|&c| (c, l2_squared(vectors.get(c as usize), node_vec)))
234            .collect();
235        sorted.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
236
237        let mut result = Vec::with_capacity(self.max_degree);
238        for (cand_id, cand_dist) in &sorted {
239            if result.len() >= self.max_degree {
240                break;
241            }
242            let dominated = result.iter().any(|&selected: &u32| {
243                let inter_dist = l2_squared(
244                    vectors.get(selected as usize),
245                    vectors.get(*cand_id as usize),
246                );
247                alpha * inter_dist <= *cand_dist
248            });
249            if !dominated {
250                result.push(*cand_id);
251            }
252        }
253        result
254    }
255
256    /// Parallel medoid finding using rayon
257    fn find_medoid_parallel(&self, vectors: &FlatVectors) -> u32 {
258        let n = vectors.len();
259        let dim = vectors.dim;
260
261        // Compute centroid in parallel
262        let centroid: Vec<f32> = (0..dim)
263            .into_par_iter()
264            .map(|d| {
265                let mut sum = 0.0f32;
266                for i in 0..n {
267                    sum += vectors.get(i)[d];
268                }
269                sum / n as f32
270            })
271            .collect();
272
273        // Find closest point to centroid in parallel
274        (0..n as u32)
275            .into_par_iter()
276            .map(|i| (i, l2_squared(vectors.get(i as usize), &centroid)))
277            .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal))
278            .map(|(id, _)| id)
279            .unwrap_or(0)
280    }
281
282    fn init_random_graph(&mut self, n: usize) {
283        use rand::prelude::*;
284        let mut rng = rand::thread_rng();
285        let degree = self.max_degree.min(n - 1);
286
287        for i in 0..n {
288            let mut neighbors = Vec::with_capacity(degree);
289            let mut attempts = 0;
290            while neighbors.len() < degree && attempts < degree * 3 {
291                let j = rng.gen_range(0..n) as u32;
292                if j != i as u32 && !neighbors.contains(&j) {
293                    neighbors.push(j);
294                }
295                attempts += 1;
296            }
297            self.neighbors[i] = neighbors;
298        }
299    }
300}
301
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `n` vectors of dimension `dim` whose components are drawn from
    /// the thread-local RNG.
    fn random_flat(n: usize, dim: usize) -> FlatVectors {
        use rand::prelude::*;
        let mut rng = rand::thread_rng();
        let mut storage = FlatVectors::with_capacity(dim, n);
        for _ in 0..n {
            let mut row: Vec<f32> = Vec::with_capacity(dim);
            for _ in 0..dim {
                row.push(rng.gen());
            }
            storage.push(&row);
        }
        storage
    }

    /// Searching for a stored vector should return a non-empty result that
    /// contains that vector's own id.
    #[test]
    fn test_vamana_build_and_search() {
        let vectors = random_flat(200, 32);
        let mut graph = VamanaGraph::new(200, 32, 64, 1.2);
        graph.build(&vectors).unwrap();

        let query = vectors.get(42);
        let (results, _) = graph.greedy_search(&vectors, query, 10);
        assert!(!results.is_empty());
        assert!(results.contains(&42));
    }

    /// After a build, no adjacency list may exceed the configured degree.
    #[test]
    fn test_vamana_bounded_degree() {
        let vectors = random_flat(100, 16);
        let mut graph = VamanaGraph::new(100, 8, 32, 1.2);
        graph.build(&vectors).unwrap();

        assert!(graph.neighbors.iter().all(|adjacent| adjacent.len() <= 8));
    }
}
338}