1use crate::distance::{l2_squared, FlatVectors, VisitedSet};
9use crate::error::{DiskAnnError, Result};
10use rayon::prelude::*;
11use std::collections::BinaryHeap;
12use std::cmp::Ordering;
13
/// Search-frontier entry whose ordering makes `BinaryHeap<Candidate>` pop the *nearest*
/// node (smallest `distance`) first.
///
/// Distances are compared with [`f32::total_cmp`], so every value, including NaN, has a
/// fixed position and the order is total (as `BinaryHeap` requires). `PartialEq` is
/// derived from the same comparison so that `Eq` and `Ord` agree.
#[derive(Clone, Copy, Debug)]
struct Candidate {
    id: u32,
    distance: f32,
}

impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Candidate {}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> Ordering {
        // Reversed comparison: a smaller distance ranks "greater", which turns the
        // standard max-heap into a min-heap on distance.
        other.distance.total_cmp(&self.distance)
    }
}
32
/// Result-set entry whose ordering keeps the *farthest* node (largest `distance`) on top
/// of a `BinaryHeap<MaxCandidate>`, so the worst kept result can be peeked or evicted.
///
/// Distances are compared with [`f32::total_cmp`], so every value, including NaN, has a
/// fixed position and the order is total (as `BinaryHeap` requires). `PartialEq` is
/// derived from the same comparison so that `Eq` and `Ord` agree.
#[derive(Clone, Copy, Debug)]
struct MaxCandidate {
    id: u32,
    distance: f32,
}

impl PartialEq for MaxCandidate {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for MaxCandidate {}

impl PartialOrd for MaxCandidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for MaxCandidate {
    fn cmp(&self, other: &Self) -> Ordering {
        self.distance.total_cmp(&other.distance)
    }
}
49
/// Proximity graph over a set of vectors, searched greedily from `medoid`.
#[derive(Debug, Clone)]
pub struct VamanaGraph {
    /// Out-adjacency lists: `neighbors[i]` holds the ids of node `i`'s out-neighbours.
    pub neighbors: Vec<Vec<u32>>,
    /// Entry point for every search; `build` sets it to the point nearest the centroid.
    pub medoid: u32,
    /// Upper bound on each node's out-degree, enforced by pruning during `build`.
    pub max_degree: usize,
    /// Beam width of the greedy searches performed while building.
    pub build_beam: usize,
    /// Pruning factor of the second build pass, which only runs when `alpha > 1.0`.
    pub alpha: f32,
}
58
59impl VamanaGraph {
60 pub fn new(n: usize, max_degree: usize, build_beam: usize, alpha: f32) -> Self {
61 Self {
62 neighbors: vec![Vec::new(); n],
63 medoid: 0,
64 max_degree,
65 build_beam,
66 alpha,
67 }
68 }
69
70 pub fn build(&mut self, vectors: &FlatVectors) -> Result<()> {
72 let n = vectors.len();
73 if n == 0 {
74 return Err(DiskAnnError::Empty);
75 }
76
77 self.medoid = self.find_medoid_parallel(vectors);
78 self.init_random_graph(n);
79
80 let passes = if self.alpha > 1.0 { 2 } else { 1 };
81 for pass in 0..passes {
82 let alpha = if pass == 0 { 1.0 } else { self.alpha };
83
84 let mut order: Vec<u32> = (0..n as u32).collect();
85 {
86 use rand::prelude::*;
87 order.shuffle(&mut rand::thread_rng());
88 }
89
90 let mut visited = VisitedSet::new(n);
92
93 for &node in &order {
94 let (candidates, _) =
95 self.greedy_search_fast(vectors, vectors.get(node as usize), self.build_beam, &mut visited);
96
97 let pruned = self.robust_prune(vectors, node, &candidates, alpha);
98 self.neighbors[node as usize] = pruned.clone();
99
100 for &neighbor in &pruned {
101 let nid = neighbor as usize;
102 if !self.neighbors[nid].contains(&node) {
103 if self.neighbors[nid].len() < self.max_degree {
104 self.neighbors[nid].push(node);
105 } else {
106 let mut combined: Vec<u32> = self.neighbors[nid].clone();
107 combined.push(node);
108 let repruned = self.robust_prune(vectors, neighbor, &combined, alpha);
109 self.neighbors[nid] = repruned;
110 }
111 }
112 }
113 }
114 }
115
116 Ok(())
117 }
118
119 pub fn greedy_search_fast(
121 &self,
122 vectors: &FlatVectors,
123 query: &[f32],
124 beam_width: usize,
125 visited: &mut VisitedSet,
126 ) -> (Vec<u32>, usize) {
127 visited.clear();
128
129 let mut candidates = BinaryHeap::<Candidate>::new();
130 let mut best = BinaryHeap::<MaxCandidate>::new();
131
132 let start = self.medoid;
133 let start_dist = l2_squared(vectors.get(start as usize), query);
134 candidates.push(Candidate { id: start, distance: start_dist });
135 best.push(MaxCandidate { id: start, distance: start_dist });
136 visited.insert(start);
137
138 let mut visit_count = 1usize;
139
140 while let Some(current) = candidates.pop() {
141 if best.len() >= beam_width {
142 if let Some(worst) = best.peek() {
143 if current.distance > worst.distance {
144 break;
145 }
146 }
147 }
148
149 for &neighbor in &self.neighbors[current.id as usize] {
150 if visited.contains(neighbor) {
151 continue;
152 }
153 visited.insert(neighbor);
154 visit_count += 1;
155
156 let dist = l2_squared(vectors.get(neighbor as usize), query);
157
158 let dominated = best.len() >= beam_width
159 && best.peek().map_or(false, |w| dist >= w.distance);
160
161 if !dominated {
162 candidates.push(Candidate { id: neighbor, distance: dist });
163 best.push(MaxCandidate { id: neighbor, distance: dist });
164 if best.len() > beam_width {
165 best.pop();
166 }
167 }
168 }
169 }
170
171 let mut result: Vec<(u32, f32)> = best.into_iter().map(|c| (c.id, c.distance)).collect();
172 result.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
173 let ids: Vec<u32> = result.into_iter().map(|(id, _)| id).collect();
174
175 (ids, visit_count)
176 }
177
178 pub fn greedy_search(
180 &self,
181 vectors: &FlatVectors,
182 query: &[f32],
183 beam_width: usize,
184 ) -> (Vec<u32>, usize) {
185 let mut visited = VisitedSet::new(vectors.len());
186 self.greedy_search_fast(vectors, query, beam_width, &mut visited)
187 }
188
189 fn robust_prune(
190 &self,
191 vectors: &FlatVectors,
192 node: u32,
193 candidates: &[u32],
194 alpha: f32,
195 ) -> Vec<u32> {
196 if candidates.is_empty() {
197 return Vec::new();
198 }
199
200 let node_vec = vectors.get(node as usize);
201 let mut sorted: Vec<(u32, f32)> = candidates
202 .iter()
203 .filter(|&&c| c != node)
204 .map(|&c| (c, l2_squared(vectors.get(c as usize), node_vec)))
205 .collect();
206 sorted.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
207
208 let mut result = Vec::with_capacity(self.max_degree);
209 for (cand_id, cand_dist) in &sorted {
210 if result.len() >= self.max_degree {
211 break;
212 }
213 let dominated = result.iter().any(|&selected: &u32| {
214 let inter_dist = l2_squared(vectors.get(selected as usize), vectors.get(*cand_id as usize));
215 alpha * inter_dist <= *cand_dist
216 });
217 if !dominated {
218 result.push(*cand_id);
219 }
220 }
221 result
222 }
223
224 fn find_medoid_parallel(&self, vectors: &FlatVectors) -> u32 {
226 let n = vectors.len();
227 let dim = vectors.dim;
228
229 let centroid: Vec<f32> = (0..dim)
231 .into_par_iter()
232 .map(|d| {
233 let mut sum = 0.0f32;
234 for i in 0..n {
235 sum += vectors.get(i)[d];
236 }
237 sum / n as f32
238 })
239 .collect();
240
241 (0..n as u32)
243 .into_par_iter()
244 .map(|i| (i, l2_squared(vectors.get(i as usize), ¢roid)))
245 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal))
246 .map(|(id, _)| id)
247 .unwrap_or(0)
248 }
249
250 fn init_random_graph(&mut self, n: usize) {
251 use rand::prelude::*;
252 let mut rng = rand::thread_rng();
253 let degree = self.max_degree.min(n - 1);
254
255 for i in 0..n {
256 let mut neighbors = Vec::with_capacity(degree);
257 let mut attempts = 0;
258 while neighbors.len() < degree && attempts < degree * 3 {
259 let j = rng.gen_range(0..n) as u32;
260 if j != i as u32 && !neighbors.contains(&j) {
261 neighbors.push(j);
262 }
263 attempts += 1;
264 }
265 self.neighbors[i] = neighbors;
266 }
267 }
268}
269
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `n` random vectors of dimension `dim` using `rand::thread_rng`.
    fn random_flat(n: usize, dim: usize) -> FlatVectors {
        use rand::prelude::*;
        let mut rng = rand::thread_rng();
        let mut fv = FlatVectors::with_capacity(dim, n);
        for _ in 0..n {
            let v: Vec<f32> = (0..dim).map(|_| rng.gen()).collect();
            fv.push(&v);
        }
        fv
    }

    #[test]
    fn test_vamana_build_and_search() {
        let vectors = random_flat(200, 32);
        let mut graph = VamanaGraph::new(200, 32, 64, 1.2);
        graph.build(&vectors).unwrap();

        let (results, _) = graph.greedy_search(&vectors, vectors.get(42), 10);
        assert!(!results.is_empty());
        assert!(results.contains(&42));
    }

    #[test]
    fn test_vamana_bounded_degree() {
        let vectors = random_flat(100, 16);
        let mut graph = VamanaGraph::new(100, 8, 32, 1.2);
        graph.build(&vectors).unwrap();

        for neighbors in &graph.neighbors {
            assert!(neighbors.len() <= 8);
        }
    }

    #[test]
    fn test_vamana_build_rejects_empty_input() {
        let vectors = FlatVectors::with_capacity(8, 0);
        let mut graph = VamanaGraph::new(0, 8, 16, 1.2);
        assert!(matches!(graph.build(&vectors), Err(DiskAnnError::Empty)));
    }

    #[test]
    fn test_vamana_single_point() {
        let vectors = random_flat(1, 4);
        let mut graph = VamanaGraph::new(1, 8, 16, 1.2);
        graph.build(&vectors).unwrap();

        assert_eq!(graph.medoid, 0);
        assert!(graph.neighbors[0].is_empty());
        let (results, _) = graph.greedy_search(&vectors, vectors.get(0), 5);
        assert_eq!(results, vec![0]);
    }

    #[test]
    fn test_vamana_no_self_loops_or_duplicate_edges() {
        let vectors = random_flat(60, 8);
        let mut graph = VamanaGraph::new(60, 6, 24, 1.2);
        graph.build(&vectors).unwrap();

        for (i, neighbors) in graph.neighbors.iter().enumerate() {
            assert!(!neighbors.contains(&(i as u32)), "node {} links to itself", i);
            let mut unique = neighbors.clone();
            unique.sort_unstable();
            unique.dedup();
            assert_eq!(unique.len(), neighbors.len(), "node {} has duplicate edges", i);
        }
    }

    #[test]
    fn test_greedy_search_results_sorted_by_distance() {
        let vectors = random_flat(80, 8);
        let mut graph = VamanaGraph::new(80, 8, 32, 1.2);
        graph.build(&vectors).unwrap();

        let query = vectors.get(7);
        let (results, visits) = graph.greedy_search(&vectors, query, 10);
        assert!(results.len() <= 10);
        assert!(visits >= results.len());
        let dists: Vec<f32> = results
            .iter()
            .map(|&id| l2_squared(vectors.get(id as usize), query))
            .collect();
        assert!(dists.windows(2).all(|w| w[0] <= w[1]));
    }
}