1use crate::distance::{l2_squared, FlatVectors, VisitedSet};
9use crate::error::{DiskAnnError, Result};
10use rayon::prelude::*;
11use std::cmp::Ordering;
12use std::collections::BinaryHeap;
13
/// Search-frontier entry ordered so that `BinaryHeap<Candidate>` pops the
/// *nearest* vertex first (a min-heap on `distance`).
///
/// Ordering and equality look at `distance` only; `id` is payload. Distances
/// are compared with [`f32::total_cmp`], so the order stays total even if a
/// NaN distance appears (a positive NaN ranks as the farthest candidate).
#[derive(Clone, Copy, Debug)]
struct Candidate {
    id: u32,
    distance: f32,
}

impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so `Eq` is reflexive (NaN included) and
        // always agrees with `Ord`.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Candidate {}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> Ordering {
        // Reversed on purpose: the smaller distance compares as "greater",
        // so the max-heap `BinaryHeap` surfaces it first.
        other.distance.total_cmp(&self.distance)
    }
}
39
/// Result-beam entry ordered by ascending `distance`, so
/// `BinaryHeap<MaxCandidate>` keeps the *farthest* retained vertex on top,
/// ready to be evicted.
///
/// Ordering and equality look at `distance` only; `id` is payload. Distances
/// are compared with [`f32::total_cmp`], giving a total order even with NaN
/// (a positive NaN ranks as the farthest entry).
#[derive(Clone, Copy, Debug)]
struct MaxCandidate {
    id: u32,
    distance: f32,
}

impl PartialEq for MaxCandidate {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so `Eq` is reflexive and matches `Ord`.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for MaxCandidate {}

impl PartialOrd for MaxCandidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for MaxCandidate {
    fn cmp(&self, other: &Self) -> Ordering {
        self.distance.total_cmp(&other.distance)
    }
}
62
/// Vamana proximity graph over a vector set, with vertices addressed by `u32` id.
#[derive(Debug, Clone)]
pub struct VamanaGraph {
    /// Out-neighbor ids of every vertex; `neighbors[v]` is the adjacency list of `v`.
    pub neighbors: Vec<Vec<u32>>,
    /// Vertex from which every greedy search starts.
    pub medoid: u32,
    /// Maximum out-degree maintained while building.
    pub max_degree: usize,
    /// Beam width of the greedy searches performed during construction.
    pub build_beam: usize,
    /// Pruning factor for the second construction pass; values `<= 1.0`
    /// mean only a single pass (with factor `1.0`) is run.
    pub alpha: f32,
}
71
72impl VamanaGraph {
73 pub fn new(n: usize, max_degree: usize, build_beam: usize, alpha: f32) -> Self {
74 Self {
75 neighbors: vec![Vec::new(); n],
76 medoid: 0,
77 max_degree,
78 build_beam,
79 alpha,
80 }
81 }
82
83 pub fn build(&mut self, vectors: &FlatVectors) -> Result<()> {
85 let n = vectors.len();
86 if n == 0 {
87 return Err(DiskAnnError::Empty);
88 }
89
90 self.medoid = self.find_medoid_parallel(vectors);
91 self.init_random_graph(n);
92
93 let passes = if self.alpha > 1.0 { 2 } else { 1 };
94 for pass in 0..passes {
95 let alpha = if pass == 0 { 1.0 } else { self.alpha };
96
97 let mut order: Vec<u32> = (0..n as u32).collect();
98 {
99 use rand::prelude::*;
100 order.shuffle(&mut rand::thread_rng());
101 }
102
103 let mut visited = VisitedSet::new(n);
105
106 for &node in &order {
107 let (candidates, _) = self.greedy_search_fast(
108 vectors,
109 vectors.get(node as usize),
110 self.build_beam,
111 &mut visited,
112 );
113
114 let pruned = self.robust_prune(vectors, node, &candidates, alpha);
115 self.neighbors[node as usize] = pruned.clone();
116
117 for &neighbor in &pruned {
118 let nid = neighbor as usize;
119 if !self.neighbors[nid].contains(&node) {
120 if self.neighbors[nid].len() < self.max_degree {
121 self.neighbors[nid].push(node);
122 } else {
123 let mut combined: Vec<u32> = self.neighbors[nid].clone();
124 combined.push(node);
125 let repruned = self.robust_prune(vectors, neighbor, &combined, alpha);
126 self.neighbors[nid] = repruned;
127 }
128 }
129 }
130 }
131 }
132
133 Ok(())
134 }
135
136 pub fn greedy_search_fast(
138 &self,
139 vectors: &FlatVectors,
140 query: &[f32],
141 beam_width: usize,
142 visited: &mut VisitedSet,
143 ) -> (Vec<u32>, usize) {
144 visited.clear();
145
146 let mut candidates = BinaryHeap::<Candidate>::new();
147 let mut best = BinaryHeap::<MaxCandidate>::new();
148
149 let start = self.medoid;
150 let start_dist = l2_squared(vectors.get(start as usize), query);
151 candidates.push(Candidate {
152 id: start,
153 distance: start_dist,
154 });
155 best.push(MaxCandidate {
156 id: start,
157 distance: start_dist,
158 });
159 visited.insert(start);
160
161 let mut visit_count = 1usize;
162
163 while let Some(current) = candidates.pop() {
164 if best.len() >= beam_width {
165 if let Some(worst) = best.peek() {
166 if current.distance > worst.distance {
167 break;
168 }
169 }
170 }
171
172 for &neighbor in &self.neighbors[current.id as usize] {
173 if visited.contains(neighbor) {
174 continue;
175 }
176 visited.insert(neighbor);
177 visit_count += 1;
178
179 let dist = l2_squared(vectors.get(neighbor as usize), query);
180
181 let dominated =
182 best.len() >= beam_width && best.peek().map_or(false, |w| dist >= w.distance);
183
184 if !dominated {
185 candidates.push(Candidate {
186 id: neighbor,
187 distance: dist,
188 });
189 best.push(MaxCandidate {
190 id: neighbor,
191 distance: dist,
192 });
193 if best.len() > beam_width {
194 best.pop();
195 }
196 }
197 }
198 }
199
200 let mut result: Vec<(u32, f32)> = best.into_iter().map(|c| (c.id, c.distance)).collect();
201 result.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
202 let ids: Vec<u32> = result.into_iter().map(|(id, _)| id).collect();
203
204 (ids, visit_count)
205 }
206
207 pub fn greedy_search(
209 &self,
210 vectors: &FlatVectors,
211 query: &[f32],
212 beam_width: usize,
213 ) -> (Vec<u32>, usize) {
214 let mut visited = VisitedSet::new(vectors.len());
215 self.greedy_search_fast(vectors, query, beam_width, &mut visited)
216 }
217
218 fn robust_prune(
219 &self,
220 vectors: &FlatVectors,
221 node: u32,
222 candidates: &[u32],
223 alpha: f32,
224 ) -> Vec<u32> {
225 if candidates.is_empty() {
226 return Vec::new();
227 }
228
229 let node_vec = vectors.get(node as usize);
230 let mut sorted: Vec<(u32, f32)> = candidates
231 .iter()
232 .filter(|&&c| c != node)
233 .map(|&c| (c, l2_squared(vectors.get(c as usize), node_vec)))
234 .collect();
235 sorted.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
236
237 let mut result = Vec::with_capacity(self.max_degree);
238 for (cand_id, cand_dist) in &sorted {
239 if result.len() >= self.max_degree {
240 break;
241 }
242 let dominated = result.iter().any(|&selected: &u32| {
243 let inter_dist = l2_squared(
244 vectors.get(selected as usize),
245 vectors.get(*cand_id as usize),
246 );
247 alpha * inter_dist <= *cand_dist
248 });
249 if !dominated {
250 result.push(*cand_id);
251 }
252 }
253 result
254 }
255
256 fn find_medoid_parallel(&self, vectors: &FlatVectors) -> u32 {
258 let n = vectors.len();
259 let dim = vectors.dim;
260
261 let centroid: Vec<f32> = (0..dim)
263 .into_par_iter()
264 .map(|d| {
265 let mut sum = 0.0f32;
266 for i in 0..n {
267 sum += vectors.get(i)[d];
268 }
269 sum / n as f32
270 })
271 .collect();
272
273 (0..n as u32)
275 .into_par_iter()
276 .map(|i| (i, l2_squared(vectors.get(i as usize), ¢roid)))
277 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal))
278 .map(|(id, _)| id)
279 .unwrap_or(0)
280 }
281
282 fn init_random_graph(&mut self, n: usize) {
283 use rand::prelude::*;
284 let mut rng = rand::thread_rng();
285 let degree = self.max_degree.min(n - 1);
286
287 for i in 0..n {
288 let mut neighbors = Vec::with_capacity(degree);
289 let mut attempts = 0;
290 while neighbors.len() < degree && attempts < degree * 3 {
291 let j = rng.gen_range(0..n) as u32;
292 if j != i as u32 && !neighbors.contains(&j) {
293 neighbors.push(j);
294 }
295 attempts += 1;
296 }
297 self.neighbors[i] = neighbors;
298 }
299 }
300}
301
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `n` vectors of dimension `dim` whose components come from
    /// `rng.gen::<f32>()` on a `StdRng` seeded with `seed`. The data is
    /// therefore reproducible; graph construction itself still shuffles
    /// with a thread-local RNG.
    fn random_flat(n: usize, dim: usize, seed: u64) -> FlatVectors {
        use rand::prelude::*;
        let mut rng = StdRng::seed_from_u64(seed);
        let mut fv = FlatVectors::with_capacity(dim, n);
        for _ in 0..n {
            let v: Vec<f32> = (0..dim).map(|_| rng.gen()).collect();
            fv.push(&v);
        }
        fv
    }

    #[test]
    fn test_vamana_build_and_search() {
        let vectors = random_flat(200, 32, 42);
        let mut graph = VamanaGraph::new(200, 32, 64, 1.2);
        graph.build(&vectors).unwrap();

        let (results, visited) = graph.greedy_search(&vectors, vectors.get(42), 10);
        assert!(!results.is_empty());
        assert!(results.len() <= 10);
        assert!(results.contains(&42));
        // Every returned vertex had its distance computed.
        assert!(visited >= results.len());
    }

    #[test]
    fn test_vamana_bounded_degree() {
        let vectors = random_flat(100, 16, 7);
        let mut graph = VamanaGraph::new(100, 8, 32, 1.2);
        graph.build(&vectors).unwrap();

        for neighbors in &graph.neighbors {
            assert!(neighbors.len() <= 8);
        }
    }

    #[test]
    fn test_vamana_edges_are_distinct_in_range_and_loop_free() {
        let n = 120;
        let vectors = random_flat(n, 8, 11);
        let mut graph = VamanaGraph::new(n, 6, 24, 1.2);
        graph.build(&vectors).unwrap();

        assert!((graph.medoid as usize) < n);
        for (id, neighbors) in graph.neighbors.iter().enumerate() {
            assert!(!neighbors.contains(&(id as u32)), "self-loop at {}", id);
            let mut sorted = neighbors.clone();
            sorted.sort_unstable();
            sorted.dedup();
            assert_eq!(sorted.len(), neighbors.len(), "duplicate edge at {}", id);
            assert!(neighbors.iter().all(|&v| (v as usize) < n));
        }
    }

    #[test]
    fn test_vamana_build_empty_is_error() {
        let vectors = FlatVectors::with_capacity(4, 0);
        let mut graph = VamanaGraph::new(0, 8, 16, 1.2);
        assert!(matches!(graph.build(&vectors), Err(DiskAnnError::Empty)));
    }

    #[test]
    fn test_vamana_single_vector() {
        let vectors = random_flat(1, 4, 1);
        let mut graph = VamanaGraph::new(1, 8, 16, 1.2);
        graph.build(&vectors).unwrap();

        assert_eq!(graph.medoid, 0);
        assert!(graph.neighbors[0].is_empty());
        let (results, _) = graph.greedy_search(&vectors, vectors.get(0), 4);
        assert_eq!(results, vec![0]);
    }
}
338}