1use crate::error::{Result, RuvectorError};
20use serde::{Deserialize, Serialize};
21use std::cmp::Reverse;
22use std::collections::{BinaryHeap, HashMap, HashSet};
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct VamanaConfig {
27 pub max_degree: usize,
29 pub search_list_size: usize,
31 pub alpha: f32,
33 pub num_build_threads: usize,
35 pub ssd_page_size: usize,
37}
38
39impl Default for VamanaConfig {
40 fn default() -> Self {
41 Self {
42 max_degree: 32,
43 search_list_size: 64,
44 alpha: 1.2,
45 num_build_threads: 1,
46 ssd_page_size: 4096,
47 }
48 }
49}
50
51impl VamanaConfig {
52 pub fn validate(&self) -> Result<()> {
54 if self.max_degree == 0 {
55 return Err(RuvectorError::InvalidParameter(
56 "max_degree must be > 0".into(),
57 ));
58 }
59 if self.search_list_size < 1 {
60 return Err(RuvectorError::InvalidParameter(
61 "search_list_size must be >= 1".into(),
62 ));
63 }
64 if self.alpha < 1.0 {
65 return Err(RuvectorError::InvalidParameter(
66 "alpha must be >= 1.0".into(),
67 ));
68 }
69 Ok(())
70 }
71}
72
73#[derive(Debug, Clone)]
75pub struct VamanaGraph {
76 pub neighbors: Vec<Vec<u32>>,
78 pub vectors: Vec<Vec<f32>>,
80 pub medoid: u32,
82 pub config: VamanaConfig,
84}
85
86impl VamanaGraph {
87 pub fn build(vectors: Vec<Vec<f32>>, config: VamanaConfig) -> Result<Self> {
89 config.validate()?;
90 let n = vectors.len();
91 if n == 0 {
92 return Ok(Self {
93 neighbors: vec![],
94 vectors: vec![],
95 medoid: 0,
96 config,
97 });
98 }
99 let dim = vectors[0].len();
100 for v in &vectors {
101 if v.len() != dim {
102 return Err(RuvectorError::DimensionMismatch {
103 expected: dim,
104 actual: v.len(),
105 });
106 }
107 }
108 let medoid = MedoidFinder::find_medoid(&vectors);
109 let mut graph = Self {
110 neighbors: vec![vec![]; n],
111 vectors,
112 medoid,
113 config,
114 };
115 for i in 0..n {
117 let mut nb = Vec::new();
118 for j in 0..n.min(graph.config.max_degree + 1) {
119 if j != i {
120 nb.push(j as u32);
121 }
122 if nb.len() >= graph.config.max_degree {
123 break;
124 }
125 }
126 graph.neighbors[i] = nb;
127 }
128 for i in 0..n {
130 let query = graph.vectors[i].clone();
131 let (cands, _) = graph.greedy_search_internal(&query, graph.config.search_list_size);
132 let mut cset: Vec<u32> = cands.into_iter().filter(|&c| c != i as u32).collect();
133 for &nb in &graph.neighbors[i] {
134 if !cset.contains(&nb) {
135 cset.push(nb);
136 }
137 }
138 let pruned = graph.robust_prune(i as u32, &cset);
139 graph.neighbors[i] = pruned.clone();
140 for &nb in &pruned {
141 let ni = nb as usize;
142 if !graph.neighbors[ni].contains(&(i as u32)) {
143 graph.neighbors[ni].push(i as u32);
144 if graph.neighbors[ni].len() > graph.config.max_degree {
145 let nbs = graph.neighbors[ni].clone();
146 graph.neighbors[ni] = graph.robust_prune(nb, &nbs);
147 }
148 }
149 }
150 }
151 Ok(graph)
152 }
153
154 pub fn search(&self, query: &[f32], top_k: usize) -> Vec<(u32, f32)> {
156 if self.vectors.is_empty() {
157 return vec![];
158 }
159 let beam = self.config.search_list_size.max(top_k);
160 let (ids, dists) = self.greedy_search_internal(query, beam);
161 ids.into_iter().zip(dists).take(top_k).collect()
162 }
163
164 fn greedy_search_internal(&self, query: &[f32], list_size: usize) -> (Vec<u32>, Vec<f32>) {
165 let mut visited = HashSet::new();
166 let mut frontier: BinaryHeap<Reverse<OrdF32Pair>> = BinaryHeap::new();
167 let mut results: Vec<(f32, u32)> = Vec::new();
168 let start = self.medoid;
169 let d = l2_sq(&self.vectors[start as usize], query);
170 frontier.push(Reverse(OrdF32Pair(d, start)));
171 visited.insert(start);
172 results.push((d, start));
173 while let Some(Reverse(OrdF32Pair(_, node))) = frontier.pop() {
174 for &nb in &self.neighbors[node as usize] {
175 if visited.insert(nb) {
176 let dist = l2_sq(&self.vectors[nb as usize], query);
177 results.push((dist, nb));
178 frontier.push(Reverse(OrdF32Pair(dist, nb)));
179 }
180 }
181 if results.len() > list_size * 2 {
182 results.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
183 results.truncate(list_size);
184 }
185 }
186 results.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
187 results.truncate(list_size);
188 (
189 results.iter().map(|r| r.1).collect(),
190 results.iter().map(|r| r.0).collect(),
191 )
192 }
193
194 fn robust_prune(&self, node_id: u32, candidates: &[u32]) -> Vec<u32> {
196 let nv = &self.vectors[node_id as usize];
197 let mut scored: Vec<(f32, u32)> = candidates
198 .iter()
199 .filter(|&&c| c != node_id)
200 .map(|&c| (l2_sq(nv, &self.vectors[c as usize]), c))
201 .collect();
202 scored.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
203 let mut sel: Vec<u32> = Vec::new();
204 for (d2n, cand) in scored {
205 if sel.len() >= self.config.max_degree {
206 break;
207 }
208 let cv = &self.vectors[cand as usize];
209 if sel
210 .iter()
211 .all(|&s| d2n <= self.config.alpha * l2_sq(&self.vectors[s as usize], cv))
212 {
213 sel.push(cand);
214 }
215 }
216 sel
217 }
218}
219
220#[derive(Debug, Clone, Serialize, Deserialize)]
222pub struct DiskNode {
223 pub node_id: u32,
224 pub neighbors: Vec<u32>,
225 pub vector: Vec<f32>,
226}
227
/// Counters describing the simulated I/O performed by a single search.
#[derive(Debug, Clone, Default)]
pub struct IOStats {
    /// Number of node reads that missed the page cache.
    pub pages_read: usize,
    /// Bytes charged for those misses (one page size per miss).
    pub bytes_read: usize,
    /// Number of node reads served from the page cache.
    pub cache_hits: usize,
}
235
236#[derive(Debug)]
238pub struct DiskIndex {
239 nodes: Vec<DiskNode>,
240 page_size: usize,
241 medoid: u32,
242 cache: PageCache,
243}
244
245impl DiskIndex {
246 pub fn from_graph(graph: &VamanaGraph, cache_size_pages: usize) -> Self {
248 let nodes = (0..graph.vectors.len())
249 .map(|i| DiskNode {
250 node_id: i as u32,
251 neighbors: graph.neighbors[i].clone(),
252 vector: graph.vectors[i].clone(),
253 })
254 .collect();
255 Self {
256 nodes,
257 page_size: graph.config.ssd_page_size,
258 medoid: graph.medoid,
259 cache: PageCache::new(cache_size_pages),
260 }
261 }
262
263 pub fn search_disk(
265 &mut self,
266 query: &[f32],
267 top_k: usize,
268 beam_width: usize,
269 ) -> (Vec<(u32, f32)>, IOStats) {
270 let mut stats = IOStats::default();
271 if self.nodes.is_empty() {
272 return (vec![], stats);
273 }
274 let mut visited = HashSet::new();
275 let mut frontier: BinaryHeap<Reverse<OrdF32Pair>> = BinaryHeap::new();
276 let mut results: Vec<(f32, u32)> = Vec::new();
277 let start = self.medoid;
278 let d = l2_sq(&self.read_node(start, &mut stats).vector.clone(), query);
279 frontier.push(Reverse(OrdF32Pair(d, start)));
280 visited.insert(start);
281 results.push((d, start));
282 while let Some(Reverse(OrdF32Pair(_, cur))) = frontier.pop() {
283 let nbs = self.read_node(cur, &mut stats).neighbors.clone();
284 for nb in nbs {
285 if visited.insert(nb) {
286 let v = self.read_node(nb, &mut stats).vector.clone();
287 let dist = l2_sq(&v, query);
288 results.push((dist, nb));
289 frontier.push(Reverse(OrdF32Pair(dist, nb)));
290 }
291 }
292 if results.len() > beam_width * 2 {
293 results.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
294 results.truncate(beam_width);
295 }
296 }
297 results.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
298 results.truncate(top_k);
299 (results.iter().map(|r| (r.1, r.0)).collect(), stats)
300 }
301
302 fn read_node(&mut self, node_id: u32, stats: &mut IOStats) -> &DiskNode {
303 let page_id = node_id as usize;
304 if self.cache.get(page_id) {
305 stats.cache_hits += 1;
306 } else {
307 stats.pages_read += 1;
308 stats.bytes_read += self.page_size;
309 self.cache.insert(page_id);
310 }
311 &self.nodes[node_id as usize]
312 }
313
314 pub fn search_with_filter<F>(
317 &mut self,
318 query: &[f32],
319 filter_fn: F,
320 top_k: usize,
321 ) -> Vec<(u32, f32)>
322 where
323 F: Fn(u32) -> bool,
324 {
325 if self.nodes.is_empty() {
326 return vec![];
327 }
328 let mut visited = HashSet::new();
329 let mut frontier: BinaryHeap<Reverse<OrdF32Pair>> = BinaryHeap::new();
330 let mut results: Vec<(f32, u32)> = Vec::new();
331 let mut io = IOStats::default();
332 let start = self.medoid;
333 let d = l2_sq(&self.read_node(start, &mut io).vector.clone(), query);
334 frontier.push(Reverse(OrdF32Pair(d, start)));
335 visited.insert(start);
336 if filter_fn(start) {
337 results.push((d, start));
338 }
339 while let Some(Reverse(OrdF32Pair(_, cur))) = frontier.pop() {
340 let nbs = self.read_node(cur, &mut io).neighbors.clone();
341 for nb in nbs {
342 if visited.insert(nb) {
343 let v = self.read_node(nb, &mut io).vector.clone();
344 let dist = l2_sq(&v, query);
345 frontier.push(Reverse(OrdF32Pair(dist, nb)));
346 if filter_fn(nb) {
347 results.push((dist, nb));
348 }
349 }
350 }
351 }
352 results.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
353 results.truncate(top_k);
354 results.iter().map(|r| (r.1, r.0)).collect()
355 }
356}
357
/// Fixed-capacity cache of page ids with least-recently-used eviction and
/// hit-rate accounting.
#[derive(Debug)]
pub struct PageCache {
    /// Maximum number of resident pages; `0` disables caching.
    capacity: usize,
    /// Logical clock, advanced on every lookup and insertion.
    clock: u64,
    /// Resident page id -> clock value of its most recent use.
    entries: HashMap<usize, u64>,
    /// Number of `get` calls that found the page.
    total_hits: u64,
    /// Number of `get` calls overall.
    total_accesses: u64,
}

impl PageCache {
    /// Creates an empty cache that holds at most `capacity` pages.
    pub fn new(capacity: usize) -> Self {
        Self {
            capacity,
            clock: 0,
            entries: HashMap::new(),
            total_hits: 0,
            total_accesses: 0,
        }
    }

    /// Looks up `page_id`, counting the access. On a hit the page becomes the
    /// most recently used entry and `true` is returned; a miss returns `false`.
    pub fn get(&mut self, page_id: usize) -> bool {
        self.total_accesses += 1;
        self.clock += 1;
        match self.entries.get_mut(&page_id) {
            Some(stamp) => {
                *stamp = self.clock;
                self.total_hits += 1;
                true
            }
            None => false,
        }
    }

    /// Makes `page_id` resident and most recently used.
    ///
    /// A page that is already resident only has its recency refreshed and
    /// nothing is evicted. Otherwise, if the cache is full, the least
    /// recently used page is evicted first. A zero-capacity cache ignores
    /// insertions.
    pub fn insert(&mut self, page_id: usize) {
        if self.capacity == 0 {
            return;
        }
        self.clock += 1;
        let now = self.clock;
        if let Some(stamp) = self.entries.get_mut(&page_id) {
            // Already cached: evicting here would throw away an unrelated page.
            *stamp = now;
            return;
        }
        if self.entries.len() >= self.capacity {
            let victim = self
                .entries
                .iter()
                .min_by_key(|&(_, &stamp)| stamp)
                .map(|(&id, _)| id);
            if let Some(id) = victim {
                self.entries.remove(&id);
            }
        }
        self.entries.insert(page_id, now);
    }

    /// Fraction of `get` calls that were hits, or `0.0` before any access.
    pub fn cache_hit_rate(&self) -> f64 {
        if self.total_accesses == 0 {
            0.0
        } else {
            self.total_hits as f64 / self.total_accesses as f64
        }
    }
}
420
421pub struct MedoidFinder;
423
424impl MedoidFinder {
425 pub fn find_medoid(vectors: &[Vec<f32>]) -> u32 {
426 if vectors.is_empty() {
427 return 0;
428 }
429 let (mut best_idx, mut best_sum) = (0u32, f32::MAX);
430 for i in 0..vectors.len() {
431 let sum: f32 = (0..vectors.len())
432 .map(|j| l2_sq(&vectors[i], &vectors[j]))
433 .sum();
434 if sum < best_sum {
435 best_sum = sum;
436 best_idx = i as u32;
437 }
438 }
439 best_idx
440 }
441}
442
/// Squared Euclidean distance between `a` and `b`.
///
/// Only the first `min(a.len(), b.len())` components are compared; any
/// trailing components of the longer slice are ignored.
fn l2_sq(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b.iter())
        .map(|(&lhs, &rhs)| {
            let diff = lhs - rhs;
            diff * diff
        })
        .sum()
}
447
/// `(distance, node id)` pair with a total order, usable as a heap key.
///
/// Pairs are ordered by distance and then by id. NaN distances compare equal
/// to each other and greater than every number, which keeps the order
/// transitive; equality is defined through the same comparison so `Eq` and
/// `Ord` always agree.
#[derive(Debug, Clone)]
struct OrdF32Pair(f32, u32);

impl PartialEq for OrdF32Pair {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for OrdF32Pair {}

impl PartialOrd for OrdF32Pair {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for OrdF32Pair {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.0
            .partial_cmp(&other.0)
            // `partial_cmp` is `None` only when a NaN is involved: rank NaN last.
            .unwrap_or_else(|| self.0.is_nan().cmp(&other.0.is_nan()))
            .then_with(|| self.1.cmp(&other.1))
    }
}
464
#[cfg(test)]
mod tests {
    use super::*;

    /// Produces `n` rows of length `dim` filled with the consecutive values
    /// `0, 1, 2, …` in row-major order.
    fn make_vecs(n: usize, dim: usize) -> Vec<Vec<f32>> {
        let mut rows: Vec<Vec<f32>> = Vec::with_capacity(n);
        for row in 0..n {
            let first = row * dim;
            rows.push((first..first + dim).map(|value| value as f32).collect());
        }
        rows
    }

    /// Default configuration with `max_degree = r` and `search_list_size = l`.
    fn default_cfg(r: usize, l: usize) -> VamanaConfig {
        VamanaConfig {
            search_list_size: l,
            max_degree: r,
            ..VamanaConfig::default()
        }
    }

    #[test]
    fn build_graph_basic() {
        let graph = VamanaGraph::build(make_vecs(10, 4), default_cfg(4, 8)).unwrap();
        assert_eq!(graph.vectors.len(), 10);
        for adjacency in graph.neighbors.iter() {
            assert!(adjacency.len() <= 4);
        }
    }

    #[test]
    fn search_accuracy() {
        let mut data = make_vecs(20, 4);
        data.push(vec![0.1, 0.1, 0.1, 0.1]);
        let graph = VamanaGraph::build(data, default_cfg(8, 30)).unwrap();
        let hits = graph.search(&[0.0; 4], 3);
        assert!(hits.iter().any(|&(id, _)| id == 20));
    }

    #[test]
    fn robust_pruning_limits_degree() {
        let graph = VamanaGraph::build(make_vecs(50, 4), default_cfg(5, 16)).unwrap();
        for adjacency in graph.neighbors.iter() {
            assert!(adjacency.len() <= 5);
        }
    }

    #[test]
    fn disk_layout_roundtrip() {
        let data = make_vecs(10, 4);
        let graph = VamanaGraph::build(data.clone(), VamanaConfig::default()).unwrap();
        let index = DiskIndex::from_graph(&graph, 16);
        for i in 0..10 {
            let node = &index.nodes[i];
            assert_eq!(node.node_id, i as u32);
            assert_eq!(node.vector, data[i]);
            assert_eq!(node.neighbors, graph.neighbors[i]);
        }
    }

    #[test]
    fn page_cache_hits_and_misses() {
        let mut cache = PageCache::new(2);
        assert!(!cache.get(0));
        cache.insert(0);
        assert!(cache.get(0));
        cache.insert(1);
        // The cache is full, so inserting a third page evicts page 0.
        cache.insert(2);
        assert!(!cache.get(0));
        assert!(cache.get(1));
    }

    #[test]
    fn cache_hit_rate() {
        let mut cache = PageCache::new(4);
        cache.insert(0);
        cache.insert(1);
        assert!(cache.get(0));
        assert!(cache.get(1));
        assert!(!cache.get(2));
        let expected = 2.0 / 3.0;
        assert!((cache.cache_hit_rate() - expected).abs() < 1e-6);
    }

    #[test]
    fn filtered_search() {
        let mut data = make_vecs(15, 4);
        data.push(vec![0.1; 4]);
        let graph = VamanaGraph::build(data, default_cfg(8, 20)).unwrap();
        let mut index = DiskIndex::from_graph(&graph, 32);
        let hits = index.search_with_filter(&[0.0; 4], |id| id % 2 == 0, 5);
        for &(id, _) in hits.iter() {
            assert_eq!(id % 2, 0);
        }
    }

    #[test]
    fn medoid_selection() {
        let points = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![0.5, 0.5],
        ];
        assert_eq!(MedoidFinder::find_medoid(&points), 3);
    }

    #[test]
    fn empty_dataset() {
        let graph = VamanaGraph::build(vec![], VamanaConfig::default()).unwrap();
        assert!(graph.vectors.is_empty());
        assert!(graph.search(&[1.0, 2.0], 5).is_empty());
    }

    #[test]
    fn single_vector() {
        let graph =
            VamanaGraph::build(vec![vec![1.0, 2.0, 3.0]], VamanaConfig::default()).unwrap();
        assert!(graph.neighbors[0].is_empty());
        let hits = graph.search(&[1.0, 2.0, 3.0], 1);
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].0, 0);
    }

    #[test]
    fn io_stats_tracking() {
        let graph = VamanaGraph::build(make_vecs(10, 4), default_cfg(4, 10)).unwrap();
        let mut index = DiskIndex::from_graph(&graph, 2);
        let (_, stats) = index.search_disk(&[0.0; 4], 3, 10);
        assert!(stats.pages_read > 0);
        assert_eq!(stats.bytes_read, stats.pages_read * 4096);
    }

    #[test]
    fn disk_search_sorted_results() {
        let graph = VamanaGraph::build(make_vecs(20, 4), default_cfg(8, 20)).unwrap();
        let mut index = DiskIndex::from_graph(&graph, 32);
        let (hits, stats) = index.search_disk(&[0.0; 4], 5, 20);
        assert_eq!(hits.len(), 5);
        for pair in hits.windows(2) {
            assert!(pair[0].1 <= pair[1].1);
        }
        assert!(stats.pages_read + stats.cache_hits > 0);
    }

    #[test]
    fn config_validation() {
        let zero_degree = VamanaConfig {
            max_degree: 0,
            ..Default::default()
        };
        assert!(zero_degree.validate().is_err());

        let small_alpha = VamanaConfig {
            alpha: 0.5,
            ..Default::default()
        };
        assert!(small_alpha.validate().is_err());

        assert!(VamanaConfig::default().validate().is_ok());
    }
}