1use crate::error::{Result, RuvectorError};
20use serde::{Deserialize, Serialize};
21use std::collections::{BinaryHeap, HashMap, HashSet};
22use std::cmp::Reverse;
23
/// Tuning parameters for building and searching a Vamana (DiskANN-style) graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VamanaConfig {
    /// Maximum out-degree per node (R in the DiskANN paper); must be > 0.
    pub max_degree: usize,
    /// Candidate/beam list size used by greedy search during build and query (L); must be >= 1.
    pub search_list_size: usize,
    /// Pruning slack factor (>= 1.0); larger values keep more diverse edges.
    /// NOTE(review): applied to squared distances in `robust_prune` — see note there.
    pub alpha: f32,
    /// NOTE(review): not consulted anywhere in this file — presumably reserved
    /// for a parallel build implemented elsewhere; confirm before relying on it.
    pub num_build_threads: usize,
    /// Bytes charged per simulated page read by `DiskIndex`.
    pub ssd_page_size: usize,
}
38
39impl Default for VamanaConfig {
40 fn default() -> Self {
41 Self { max_degree: 32, search_list_size: 64, alpha: 1.2, num_build_threads: 1, ssd_page_size: 4096 }
42 }
43}
44
45impl VamanaConfig {
46 pub fn validate(&self) -> Result<()> {
48 if self.max_degree == 0 {
49 return Err(RuvectorError::InvalidParameter("max_degree must be > 0".into()));
50 }
51 if self.search_list_size < 1 {
52 return Err(RuvectorError::InvalidParameter("search_list_size must be >= 1".into()));
53 }
54 if self.alpha < 1.0 {
55 return Err(RuvectorError::InvalidParameter("alpha must be >= 1.0".into()));
56 }
57 Ok(())
58 }
59}
60
/// In-memory Vamana (DiskANN-style) proximity graph.
#[derive(Debug, Clone)]
pub struct VamanaGraph {
    /// Adjacency lists: `neighbors[i]` holds the out-edges of node `i`
    /// (at most `config.max_degree` entries after a successful build).
    pub neighbors: Vec<Vec<u32>>,
    /// Vector data; the index in this Vec doubles as the node id.
    pub vectors: Vec<Vec<f32>>,
    /// Entry point for every search: index of the dataset medoid.
    pub medoid: u32,
    /// Parameters the graph was built with.
    pub config: VamanaConfig,
}
73
impl VamanaGraph {
    /// Builds a Vamana graph over `vectors` in two phases: every node is first
    /// seeded with up to `max_degree` neighbors, then each node is re-linked by
    /// a greedy search followed by robust pruning, inserting reverse edges and
    /// re-pruning any node whose degree overflows.
    ///
    /// # Errors
    /// Returns `InvalidParameter` if `config` fails validation, or
    /// `DimensionMismatch` if the vectors are not all the same length.
    pub fn build(vectors: Vec<Vec<f32>>, config: VamanaConfig) -> Result<Self> {
        config.validate()?;
        let n = vectors.len();
        if n == 0 {
            // Empty index: medoid 0 is a placeholder; search() guards on empty.
            return Ok(Self { neighbors: vec![], vectors: vec![], medoid: 0, config });
        }
        let dim = vectors[0].len();
        for v in &vectors {
            if v.len() != dim {
                return Err(RuvectorError::DimensionMismatch { expected: dim, actual: v.len() });
            }
        }
        let medoid = MedoidFinder::find_medoid(&vectors);
        let mut graph = Self { neighbors: vec![vec![]; n], vectors, medoid, config };
        // Phase 1: deterministic bootstrap — node i points at the first
        // max_degree ids other than itself (not a random init as in the paper).
        for i in 0..n {
            let mut nb = Vec::new();
            for j in 0..n.min(graph.config.max_degree + 1) {
                if j != i { nb.push(j as u32); }
                if nb.len() >= graph.config.max_degree { break; }
            }
            graph.neighbors[i] = nb;
        }
        // Phase 2: one refinement pass per node.
        for i in 0..n {
            let query = graph.vectors[i].clone();
            let (cands, _) = graph.greedy_search_internal(&query, graph.config.search_list_size);
            // Candidate set = search results plus current neighbors, minus self.
            let mut cset: Vec<u32> = cands.into_iter().filter(|&c| c != i as u32).collect();
            for &nb in &graph.neighbors[i] {
                if !cset.contains(&nb) { cset.push(nb); }
            }
            let pruned = graph.robust_prune(i as u32, &cset);
            graph.neighbors[i] = pruned.clone();
            // Insert reverse edges; re-prune any neighbor whose degree overflows.
            for &nb in &pruned {
                let ni = nb as usize;
                if !graph.neighbors[ni].contains(&(i as u32)) {
                    graph.neighbors[ni].push(i as u32);
                    if graph.neighbors[ni].len() > graph.config.max_degree {
                        let nbs = graph.neighbors[ni].clone();
                        graph.neighbors[ni] = graph.robust_prune(nb, &nbs);
                    }
                }
            }
        }
        Ok(graph)
    }

    /// Returns up to `top_k` (node id, squared-L2 distance) pairs nearest to
    /// `query`, sorted by ascending distance. Empty input yields an empty Vec.
    pub fn search(&self, query: &[f32], top_k: usize) -> Vec<(u32, f32)> {
        if self.vectors.is_empty() { return vec![]; }
        // Beam must be at least top_k so enough candidates survive truncation.
        let beam = self.config.search_list_size.max(top_k);
        let (ids, dists) = self.greedy_search_internal(query, beam);
        ids.into_iter().zip(dists).take(top_k).collect()
    }

    /// Best-first traversal from the medoid; returns the `list_size` closest
    /// visited nodes as parallel (ids, distances) vectors sorted ascending.
    ///
    /// NOTE(review): there is no early-termination test against the current
    /// worst result, so this visits every node reachable from the medoid —
    /// O(n) work per query rather than a beam-bounded greedy search. Confirm
    /// whether that is intended before using this at scale.
    fn greedy_search_internal(&self, query: &[f32], list_size: usize) -> (Vec<u32>, Vec<f32>) {
        let mut visited = HashSet::new();
        // Min-heap on distance via Reverse; OrdF32Pair breaks ties by node id.
        let mut frontier: BinaryHeap<Reverse<OrdF32Pair>> = BinaryHeap::new();
        let mut results: Vec<(f32, u32)> = Vec::new();
        let start = self.medoid;
        let d = l2_sq(&self.vectors[start as usize], query);
        frontier.push(Reverse(OrdF32Pair(d, start)));
        visited.insert(start);
        results.push((d, start));
        while let Some(Reverse(OrdF32Pair(_, node))) = frontier.pop() {
            for &nb in &self.neighbors[node as usize] {
                if visited.insert(nb) {
                    let dist = l2_sq(&self.vectors[nb as usize], query);
                    results.push((dist, nb));
                    frontier.push(Reverse(OrdF32Pair(dist, nb)));
                }
            }
            // Keep the working set bounded; sorting only at 2x amortizes the cost.
            if results.len() > list_size * 2 {
                results.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
                results.truncate(list_size);
            }
        }
        results.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
        results.truncate(list_size);
        (results.iter().map(|r| r.1).collect(), results.iter().map(|r| r.0).collect())
    }

    /// Alpha-pruned neighbor selection (RobustPrune): walks candidates in
    /// ascending distance from `node_id` and keeps a candidate only if no
    /// already-selected neighbor alpha-dominates it, up to `max_degree` picks.
    ///
    /// NOTE(review): `alpha` is applied to *squared* L2 distances here, so the
    /// effective slack on true distance is sqrt(alpha) — confirm intended.
    fn robust_prune(&self, node_id: u32, candidates: &[u32]) -> Vec<u32> {
        let nv = &self.vectors[node_id as usize];
        let mut scored: Vec<(f32, u32)> = candidates.iter()
            .filter(|&&c| c != node_id)
            .map(|&c| (l2_sq(nv, &self.vectors[c as usize]), c))
            .collect();
        scored.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
        let mut sel: Vec<u32> = Vec::new();
        for (d2n, cand) in scored {
            if sel.len() >= self.config.max_degree { break; }
            let cv = &self.vectors[cand as usize];
            // Keep cand only if it is not within alpha-slack of a selected node,
            // which encourages edges in diverse directions.
            if sel.iter().all(|&s| d2n <= self.config.alpha * l2_sq(&self.vectors[s as usize], cv)) {
                sel.push(cand);
            }
        }
        sel
    }
}
177
/// One node as it would be laid out on disk: id, adjacency list, and raw vector.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DiskNode {
    /// Node id; equals this record's position in `DiskIndex::nodes`.
    pub node_id: u32,
    /// Out-edges of this node.
    pub neighbors: Vec<u32>,
    /// The node's vector data.
    pub vector: Vec<f32>,
}
185
/// Counters for the simulated SSD I/O performed while reading nodes.
#[derive(Debug, Clone, Default)]
pub struct IOStats {
    /// Node reads that missed the page cache (one page per node).
    pub pages_read: usize,
    /// Total bytes charged for cache misses: `pages_read * page_size`.
    pub bytes_read: usize,
    /// Node reads served from the page cache.
    pub cache_hits: usize,
}
193
/// Simulated SSD-resident index: per-node records plus an LRU page cache.
#[derive(Debug)]
pub struct DiskIndex {
    /// All node records indexed by node id (stands in for the on-disk file).
    nodes: Vec<DiskNode>,
    /// Bytes charged per simulated page read; copied from `VamanaConfig::ssd_page_size`.
    page_size: usize,
    /// Search entry point copied from the source graph.
    medoid: u32,
    /// LRU cache of recently read pages (one page per node in this simulation).
    cache: PageCache,
}
202
203impl DiskIndex {
204 pub fn from_graph(graph: &VamanaGraph, cache_size_pages: usize) -> Self {
206 let nodes = (0..graph.vectors.len()).map(|i| DiskNode {
207 node_id: i as u32, neighbors: graph.neighbors[i].clone(), vector: graph.vectors[i].clone(),
208 }).collect();
209 Self { nodes, page_size: graph.config.ssd_page_size, medoid: graph.medoid, cache: PageCache::new(cache_size_pages) }
210 }
211
212 pub fn search_disk(&mut self, query: &[f32], top_k: usize, beam_width: usize) -> (Vec<(u32, f32)>, IOStats) {
214 let mut stats = IOStats::default();
215 if self.nodes.is_empty() { return (vec![], stats); }
216 let mut visited = HashSet::new();
217 let mut frontier: BinaryHeap<Reverse<OrdF32Pair>> = BinaryHeap::new();
218 let mut results: Vec<(f32, u32)> = Vec::new();
219 let start = self.medoid;
220 let d = l2_sq(&self.read_node(start, &mut stats).vector.clone(), query);
221 frontier.push(Reverse(OrdF32Pair(d, start)));
222 visited.insert(start);
223 results.push((d, start));
224 while let Some(Reverse(OrdF32Pair(_, cur))) = frontier.pop() {
225 let nbs = self.read_node(cur, &mut stats).neighbors.clone();
226 for nb in nbs {
227 if visited.insert(nb) {
228 let v = self.read_node(nb, &mut stats).vector.clone();
229 let dist = l2_sq(&v, query);
230 results.push((dist, nb));
231 frontier.push(Reverse(OrdF32Pair(dist, nb)));
232 }
233 }
234 if results.len() > beam_width * 2 {
235 results.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
236 results.truncate(beam_width);
237 }
238 }
239 results.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
240 results.truncate(top_k);
241 (results.iter().map(|r| (r.1, r.0)).collect(), stats)
242 }
243
244 fn read_node(&mut self, node_id: u32, stats: &mut IOStats) -> &DiskNode {
245 let page_id = node_id as usize;
246 if self.cache.get(page_id) { stats.cache_hits += 1; }
247 else { stats.pages_read += 1; stats.bytes_read += self.page_size; self.cache.insert(page_id); }
248 &self.nodes[node_id as usize]
249 }
250
251 pub fn search_with_filter<F>(&mut self, query: &[f32], filter_fn: F, top_k: usize) -> Vec<(u32, f32)>
254 where F: Fn(u32) -> bool {
255 if self.nodes.is_empty() { return vec![]; }
256 let mut visited = HashSet::new();
257 let mut frontier: BinaryHeap<Reverse<OrdF32Pair>> = BinaryHeap::new();
258 let mut results: Vec<(f32, u32)> = Vec::new();
259 let mut io = IOStats::default();
260 let start = self.medoid;
261 let d = l2_sq(&self.read_node(start, &mut io).vector.clone(), query);
262 frontier.push(Reverse(OrdF32Pair(d, start)));
263 visited.insert(start);
264 if filter_fn(start) { results.push((d, start)); }
265 while let Some(Reverse(OrdF32Pair(_, cur))) = frontier.pop() {
266 let nbs = self.read_node(cur, &mut io).neighbors.clone();
267 for nb in nbs {
268 if visited.insert(nb) {
269 let v = self.read_node(nb, &mut io).vector.clone();
270 let dist = l2_sq(&v, query);
271 frontier.push(Reverse(OrdF32Pair(dist, nb)));
272 if filter_fn(nb) { results.push((dist, nb)); }
273 }
274 }
275 }
276 results.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
277 results.truncate(top_k);
278 results.iter().map(|r| (r.1, r.0)).collect()
279 }
280}
281
/// Fixed-capacity LRU page cache with hit-rate accounting.
///
/// Recency is tracked with a logical clock: every access stamps the entry
/// with the current tick, and eviction removes the smallest stamp.
#[derive(Debug)]
pub struct PageCache {
    /// Maximum number of cached pages; 0 disables caching entirely.
    capacity: usize,
    /// Logical clock, incremented on every get/insert.
    clock: u64,
    /// page_id -> last-access tick.
    entries: HashMap<usize, u64>,
    /// Number of `get` calls that hit.
    total_hits: u64,
    /// Number of `get` calls overall.
    total_accesses: u64,
}

impl PageCache {
    /// Creates an empty cache holding at most `capacity` pages.
    pub fn new(capacity: usize) -> Self {
        Self { capacity, clock: 0, entries: HashMap::new(), total_hits: 0, total_accesses: 0 }
    }

    /// Looks up `page_id`, recording the access and refreshing its recency on
    /// a hit. Returns true iff the page was cached.
    pub fn get(&mut self, page_id: usize) -> bool {
        self.total_accesses += 1;
        self.clock += 1;
        if let Some(ts) = self.entries.get_mut(&page_id) {
            *ts = self.clock;
            self.total_hits += 1;
            true
        } else {
            false
        }
    }

    /// Inserts `page_id`, evicting the least-recently-used page when full.
    /// Re-inserting a page that is already cached only refreshes its recency.
    pub fn insert(&mut self, page_id: usize) {
        if self.capacity == 0 { return; }
        // Only evict when the key is genuinely new and the cache is full.
        // Previously a re-insert of an existing page at capacity would
        // wrongly evict an unrelated page.
        if !self.entries.contains_key(&page_id) && self.entries.len() >= self.capacity {
            let lru = self.entries.iter().min_by_key(|&(_, ts)| *ts).map(|(&k, _)| k);
            if let Some(k) = lru { self.entries.remove(&k); }
        }
        self.clock += 1;
        self.entries.insert(page_id, self.clock);
    }

    /// Fraction of `get` calls that hit; 0.0 before any access.
    pub fn cache_hit_rate(&self) -> f64 {
        if self.total_accesses == 0 { 0.0 } else { self.total_hits as f64 / self.total_accesses as f64 }
    }
}
322
/// Namespace type for medoid selection over a set of vectors.
pub struct MedoidFinder;

impl MedoidFinder {
    /// Finds the medoid: the index of the vector whose summed squared L2
    /// distance to every vector in the set (itself included, contributing
    /// zero) is smallest. Returns 0 for an empty slice; the first winner is
    /// kept on ties. O(n^2 * dim).
    pub fn find_medoid(vectors: &[Vec<f32>]) -> u32 {
        let mut medoid = 0u32;
        let mut lowest = f32::MAX;
        for (i, candidate) in vectors.iter().enumerate() {
            let total: f32 = vectors.iter().map(|other| l2_sq(candidate, other)).sum();
            if total < lowest {
                lowest = total;
                medoid = i as u32;
            }
        }
        medoid
    }
}

/// Squared Euclidean distance between two slices (zips to the shorter length
/// if they differ; callers in this file keep dimensions equal).
fn l2_sq(a: &[f32], b: &[f32]) -> f32 {
    let mut acc = 0.0f32;
    for (x, y) in a.iter().zip(b.iter()) {
        let d = x - y;
        acc += d * d;
    }
    acc
}
342
/// Totally-ordered (distance, node id) pair used inside `BinaryHeap`.
///
/// Distances compare via `f32::total_cmp`, so NaN sorts after all ordinary
/// values instead of comparing Equal to everything (the old
/// `partial_cmp(..).unwrap_or(Equal)` broke transitivity for NaN, which the
/// heap relies on). Ties are broken by node id. `PartialEq` is implemented in
/// terms of `cmp` so that `a == b` agrees with `cmp(a, b) == Equal`, as the
/// `Ord` contract requires; the previously derived `PartialEq` disagreed on NaN.
#[derive(Debug, Clone)]
struct OrdF32Pair(f32, u32);
impl PartialEq for OrdF32Pair {
    fn eq(&self, other: &Self) -> bool { self.cmp(other) == std::cmp::Ordering::Equal }
}
impl Eq for OrdF32Pair {}
impl PartialOrd for OrdF32Pair {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> { Some(self.cmp(other)) }
}
impl Ord for OrdF32Pair {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.0.total_cmp(&other.0).then(self.1.cmp(&other.1))
    }
}
354
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `n` vectors of dimension `dim` with strictly increasing
    /// coordinates, so every vector is distinct and ordering is predictable.
    fn make_vecs(n: usize, dim: usize) -> Vec<Vec<f32>> {
        (0..n).map(|i| (0..dim).map(|d| (i * dim + d) as f32).collect()).collect()
    }
    /// Default config with overridden max_degree `r` and search_list_size `l`.
    fn default_cfg(r: usize, l: usize) -> VamanaConfig {
        VamanaConfig { max_degree: r, search_list_size: l, ..Default::default() }
    }

    // Build succeeds and every node respects the degree bound.
    #[test]
    fn build_graph_basic() {
        let g = VamanaGraph::build(make_vecs(10, 4), default_cfg(4, 8)).unwrap();
        assert_eq!(g.vectors.len(), 10);
        for nb in &g.neighbors { assert!(nb.len() <= 4); }
    }

    // The vector planted nearest the query (id 20) must be retrieved.
    #[test]
    fn search_accuracy() {
        let mut v = make_vecs(20, 4);
        v.push(vec![0.1, 0.1, 0.1, 0.1]);
        let g = VamanaGraph::build(v, default_cfg(8, 30)).unwrap();
        let r = g.search(&[0.0; 4], 3);
        assert!(r.iter().any(|&(id, _)| id == 20));
    }

    // Robust pruning keeps degree <= max_degree even on a larger set.
    #[test]
    fn robust_pruning_limits_degree() {
        let g = VamanaGraph::build(make_vecs(50, 4), default_cfg(5, 16)).unwrap();
        for nb in &g.neighbors { assert!(nb.len() <= 5); }
    }

    // Converting a graph to disk layout preserves ids, vectors, and edges.
    #[test]
    fn disk_layout_roundtrip() {
        let v = make_vecs(10, 4);
        let g = VamanaGraph::build(v.clone(), VamanaConfig::default()).unwrap();
        let d = DiskIndex::from_graph(&g, 16);
        for i in 0..10 {
            assert_eq!(d.nodes[i].node_id, i as u32);
            assert_eq!(d.nodes[i].vector, v[i]);
            assert_eq!(d.nodes[i].neighbors, g.neighbors[i]);
        }
    }

    // LRU eviction: with capacity 2, inserting a third page evicts the
    // least recently used one (page 0), while page 1 survives.
    #[test]
    fn page_cache_hits_and_misses() {
        let mut c = PageCache::new(2);
        assert!(!c.get(0));
        c.insert(0);
        assert!(c.get(0));
        c.insert(1);
        c.insert(2); assert!(!c.get(0));
        assert!(c.get(1));
    }

    // Two hits out of three accesses -> hit rate 2/3.
    #[test]
    fn cache_hit_rate() {
        let mut c = PageCache::new(4);
        c.insert(0); c.insert(1);
        assert!(c.get(0)); assert!(c.get(1)); assert!(!c.get(2));
        assert!((c.cache_hit_rate() - 2.0 / 3.0).abs() < 1e-6);
    }

    // Filtered search returns only ids accepted by the predicate.
    #[test]
    fn filtered_search() {
        let mut v = make_vecs(15, 4);
        v.push(vec![0.1; 4]);
        let g = VamanaGraph::build(v, default_cfg(8, 20)).unwrap();
        let mut d = DiskIndex::from_graph(&g, 32);
        let r = d.search_with_filter(&[0.0; 4], |id| id % 2 == 0, 5);
        for &(id, _) in &r { assert_eq!(id % 2, 0); }
    }

    // The central point (0.5, 0.5) minimizes total distance to the corners.
    #[test]
    fn medoid_selection() {
        let v = vec![vec![0.0, 0.0], vec![1.0, 0.0], vec![0.0, 1.0], vec![0.5, 0.5]];
        assert_eq!(MedoidFinder::find_medoid(&v), 3);
    }

    // Building and searching an empty dataset must not panic.
    #[test]
    fn empty_dataset() {
        let g = VamanaGraph::build(vec![], VamanaConfig::default()).unwrap();
        assert!(g.vectors.is_empty());
        assert!(g.search(&[1.0, 2.0], 5).is_empty());
    }

    // A single node has no neighbors but is still findable.
    #[test]
    fn single_vector() {
        let g = VamanaGraph::build(vec![vec![1.0, 2.0, 3.0]], VamanaConfig::default()).unwrap();
        assert!(g.neighbors[0].is_empty());
        let r = g.search(&[1.0, 2.0, 3.0], 1);
        assert_eq!(r.len(), 1);
        assert_eq!(r[0].0, 0);
    }

    // bytes_read must equal pages_read times the default 4096-byte page size.
    #[test]
    fn io_stats_tracking() {
        let g = VamanaGraph::build(make_vecs(10, 4), default_cfg(4, 10)).unwrap();
        let mut d = DiskIndex::from_graph(&g, 2);
        let (_, s) = d.search_disk(&[0.0; 4], 3, 10);
        assert!(s.pages_read > 0);
        assert_eq!(s.bytes_read, s.pages_read * 4096);
    }

    // Disk search returns exactly top_k results in ascending distance order.
    #[test]
    fn disk_search_sorted_results() {
        let g = VamanaGraph::build(make_vecs(20, 4), default_cfg(8, 20)).unwrap();
        let mut d = DiskIndex::from_graph(&g, 32);
        let (r, s) = d.search_disk(&[0.0; 4], 5, 20);
        assert_eq!(r.len(), 5);
        for w in r.windows(2) { assert!(w[0].1 <= w[1].1); }
        assert!(s.pages_read + s.cache_hits > 0);
    }

    // Zero degree and sub-1.0 alpha are rejected; defaults are valid.
    #[test]
    fn config_validation() {
        assert!(VamanaConfig { max_degree: 0, ..Default::default() }.validate().is_err());
        assert!(VamanaConfig { alpha: 0.5, ..Default::default() }.validate().is_err());
        assert!(VamanaConfig::default().validate().is_ok());
    }
}