1use crate::distance::{l2_squared, FlatVectors, VisitedSet};
4use crate::error::{DiskAnnError, Result};
5use crate::graph::VamanaGraph;
6use crate::pq::ProductQuantizer;
7use memmap2::{Mmap, MmapOptions};
8use std::collections::HashMap;
9use std::fs::{self, File};
10use std::io::{BufWriter, Write};
11use std::path::{Path, PathBuf};
12
/// One hit returned by [`DiskAnnIndex::search`].
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// Caller-supplied identifier of the matched vector (the string passed to `insert`).
    pub id: String,
    /// Value of `l2_squared(stored_vector, query)` for this hit; smaller means closer.
    pub distance: f32,
}
19
/// Tunables for building, searching and persisting a [`DiskAnnIndex`].
#[derive(Debug, Clone)]
pub struct DiskAnnConfig {
    /// Required length of every inserted vector and every query.
    pub dim: usize,
    /// Forwarded to `VamanaGraph::new` (presumably the per-node out-degree cap — confirm
    /// against the graph module).
    pub max_degree: usize,
    /// Forwarded to `VamanaGraph::new` as the build-time beam parameter.
    pub build_beam: usize,
    /// Beam width used by `search`; the effective width is `max(search_beam, k)`.
    pub search_beam: usize,
    /// Forwarded to `VamanaGraph::new` (presumably the Vamana pruning factor — confirm).
    pub alpha: f32,
    /// Number of product-quantization subspaces; `0` disables PQ training in `build`.
    pub pq_subspaces: usize,
    /// Iteration count handed to `ProductQuantizer::train`.
    pub pq_iterations: usize,
    /// When set, `build` persists the index into this directory after a successful build.
    pub storage_path: Option<PathBuf>,
}

impl Default for DiskAnnConfig {
    /// Returns the stock configuration: 128-d vectors, degree 64, build beam 128,
    /// search beam 64, alpha 1.2, PQ disabled, and no on-disk storage.
    fn default() -> Self {
        // Graph-shape parameters are grouped so they can be read side by side.
        let (max_degree, build_beam, search_beam) = (64, 128, 64);
        // Quantizer parameters: zero subspaces means "no PQ".
        let (pq_subspaces, pq_iterations) = (0, 10);

        DiskAnnConfig {
            storage_path: None,
            pq_iterations,
            pq_subspaces,
            alpha: 1.2,
            search_beam,
            build_beam,
            max_degree,
            dim: 128,
        }
    }
}
55
/// In-memory vector index backed by a Vamana graph, with optional product
/// quantization and on-disk persistence (`save` / `load`).
pub struct DiskAnnIndex {
    /// Parameters the index was created or loaded with.
    config: DiskAnnConfig,
    /// Vector storage; a vector's position here is its internal node id.
    vectors: FlatVectors,
    /// Internal node id -> external string id. Entries are never removed, not even by `delete`.
    id_map: Vec<String>,
    /// External string id -> internal node id for ids that are currently live.
    id_reverse: HashMap<String, u32>,
    /// Search graph; `None` until `build` or `load` has populated it.
    graph: Option<VamanaGraph>,
    /// Trained quantizer; only populated when `config.pq_subspaces > 0` or `pq.bin` was loaded.
    pq: Option<ProductQuantizer>,
    /// One PQ code per stored vector, in node-id order (empty when PQ is unused).
    pq_codes: Vec<Vec<u8>>,
    /// `true` once a graph is available; reset to `false` by every `insert`.
    built: bool,
    /// Visited-set sized to the vector count, allocated by `build`/`load`.
    /// NOTE(review): not read by any method visible in this file — confirm it is still needed.
    visited: Option<VisitedSet>,
    /// Memory map of `vectors.bin` kept alive after `load`; `None` for indexes built in memory.
    mmap: Option<Mmap>,
}
78
79impl DiskAnnIndex {
80 pub fn new(config: DiskAnnConfig) -> Self {
82 let dim = config.dim;
83 Self {
84 config,
85 vectors: FlatVectors::new(dim),
86 id_map: Vec::new(),
87 id_reverse: HashMap::new(),
88 graph: None,
89 pq: None,
90 pq_codes: Vec::new(),
91 built: false,
92 visited: None,
93 mmap: None,
94 }
95 }
96
97 pub fn insert(&mut self, id: String, vector: Vec<f32>) -> Result<()> {
99 if vector.len() != self.config.dim {
100 return Err(DiskAnnError::DimensionMismatch {
101 expected: self.config.dim,
102 actual: vector.len(),
103 });
104 }
105 if self.id_reverse.contains_key(&id) {
106 return Err(DiskAnnError::InvalidConfig(format!("Duplicate ID: {id}")));
107 }
108
109 let idx = self.vectors.len() as u32;
110 self.id_reverse.insert(id.clone(), idx);
111 self.id_map.push(id);
112 self.vectors.push(&vector);
113 self.built = false;
114 Ok(())
115 }
116
117 pub fn insert_batch(&mut self, entries: Vec<(String, Vec<f32>)>) -> Result<()> {
119 for (id, vector) in entries {
120 self.insert(id, vector)?;
121 }
122 Ok(())
123 }
124
125 pub fn build(&mut self) -> Result<()> {
127 let n = self.vectors.len();
128 if n == 0 {
129 return Err(DiskAnnError::Empty);
130 }
131
132 if self.config.pq_subspaces > 0 {
134 let vecs: Vec<Vec<f32>> = (0..n)
136 .map(|i| self.vectors.get(i).to_vec())
137 .collect();
138 let mut pq = ProductQuantizer::new(self.config.dim, self.config.pq_subspaces)?;
139 pq.train(&vecs, self.config.pq_iterations)?;
140
141 self.pq_codes = vecs
142 .iter()
143 .map(|v| pq.encode(v))
144 .collect::<Result<Vec<_>>>()?;
145
146 self.pq = Some(pq);
147 }
148
149 let mut graph = VamanaGraph::new(
151 n,
152 self.config.max_degree,
153 self.config.build_beam,
154 self.config.alpha,
155 );
156 graph.build(&self.vectors)?;
157 self.graph = Some(graph);
158
159 self.visited = Some(VisitedSet::new(n));
161 self.built = true;
162
163 if let Some(ref path) = self.config.storage_path {
164 self.save(path)?;
165 }
166
167 Ok(())
168 }
169
170 pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
172 if !self.built {
173 return Err(DiskAnnError::NotBuilt);
174 }
175 if query.len() != self.config.dim {
176 return Err(DiskAnnError::DimensionMismatch {
177 expected: self.config.dim,
178 actual: query.len(),
179 });
180 }
181
182 let graph = self.graph.as_ref().unwrap();
183 let beam = self.config.search_beam.max(k);
184
185 let (candidates, _) = graph.greedy_search(&self.vectors, query, beam);
186
187 let mut scored: Vec<(u32, f32)> = candidates
189 .into_iter()
190 .map(|id| (id, l2_squared(self.vectors.get(id as usize), query)))
191 .collect();
192 scored.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
193
194 Ok(scored
195 .into_iter()
196 .take(k)
197 .map(|(id, dist)| SearchResult {
198 id: self.id_map[id as usize].clone(),
199 distance: dist,
200 })
201 .collect())
202 }
203
204 pub fn count(&self) -> usize {
206 self.vectors.len()
207 }
208
209 pub fn delete(&mut self, id: &str) -> Result<bool> {
211 if let Some(&idx) = self.id_reverse.get(id) {
212 self.vectors.zero_out(idx as usize);
213 self.id_reverse.remove(id);
214 Ok(true)
215 } else {
216 Ok(false)
217 }
218 }
219
220 pub fn save(&self, dir: &Path) -> Result<()> {
222 fs::create_dir_all(dir)?;
223
224 let vec_path = dir.join("vectors.bin");
226 let mut f = BufWriter::new(File::create(&vec_path)?);
227 let n = self.vectors.len() as u64;
228 let dim = self.config.dim as u64;
229 f.write_all(&n.to_le_bytes())?;
230 f.write_all(&dim.to_le_bytes())?;
231 let byte_slice = unsafe {
233 std::slice::from_raw_parts(
234 self.vectors.data.as_ptr() as *const u8,
235 self.vectors.data.len() * 4,
236 )
237 };
238 f.write_all(byte_slice)?;
239 f.flush()?;
240
241 let graph_path = dir.join("graph.bin");
243 let mut f = BufWriter::new(File::create(&graph_path)?);
244 if let Some(ref graph) = self.graph {
245 f.write_all(&(graph.medoid as u64).to_le_bytes())?;
246 f.write_all(&(graph.neighbors.len() as u64).to_le_bytes())?;
247 for neighbors in &graph.neighbors {
248 f.write_all(&(neighbors.len() as u32).to_le_bytes())?;
249 for &n in neighbors {
250 f.write_all(&n.to_le_bytes())?;
251 }
252 }
253 }
254 f.flush()?;
255
256 let ids_path = dir.join("ids.json");
258 let ids_json = serde_json::to_string(&self.id_map)
259 .map_err(|e| DiskAnnError::Serialization(e.to_string()))?;
260 fs::write(&ids_path, ids_json)?;
261
262 if let Some(ref pq) = self.pq {
264 let pq_path = dir.join("pq.bin");
265 let pq_bytes = bincode::encode_to_vec(pq, bincode::config::standard())
266 .map_err(|e| DiskAnnError::Serialization(e.to_string()))?;
267 fs::write(&pq_path, pq_bytes)?;
268
269 let codes_path = dir.join("pq_codes.bin");
271 let mut f = BufWriter::new(File::create(&codes_path)?);
272 for codes in &self.pq_codes {
273 f.write_all(codes)?;
274 }
275 f.flush()?;
276 }
277
278 let config_path = dir.join("config.json");
280 let config_json = serde_json::json!({
281 "dim": self.config.dim,
282 "max_degree": self.config.max_degree,
283 "build_beam": self.config.build_beam,
284 "search_beam": self.config.search_beam,
285 "alpha": self.config.alpha,
286 "pq_subspaces": self.config.pq_subspaces,
287 "count": self.vectors.len(),
288 "built": self.built,
289 });
290 fs::write(&config_path, serde_json::to_string_pretty(&config_json).unwrap())?;
291
292 Ok(())
293 }
294
295 pub fn load(dir: &Path) -> Result<Self> {
297 let config_json: serde_json::Value =
299 serde_json::from_str(&fs::read_to_string(dir.join("config.json"))?)
300 .map_err(|e| DiskAnnError::Serialization(e.to_string()))?;
301
302 let dim = config_json["dim"].as_u64().unwrap() as usize;
303 let max_degree = config_json["max_degree"].as_u64().unwrap() as usize;
304 let build_beam = config_json["build_beam"].as_u64().unwrap() as usize;
305 let search_beam = config_json["search_beam"].as_u64().unwrap() as usize;
306 let alpha = config_json["alpha"].as_f64().unwrap() as f32;
307 let pq_subspaces = config_json["pq_subspaces"].as_u64().unwrap_or(0) as usize;
308
309 let config = DiskAnnConfig {
310 dim,
311 max_degree,
312 build_beam,
313 search_beam,
314 alpha,
315 pq_subspaces,
316 storage_path: Some(dir.to_path_buf()),
317 ..Default::default()
318 };
319
320 let vec_file = File::open(dir.join("vectors.bin"))?;
322 let mmap = unsafe { MmapOptions::new().map(&vec_file)? };
323
324 let n = u64::from_le_bytes(mmap[0..8].try_into().unwrap()) as usize;
325 let file_dim = u64::from_le_bytes(mmap[8..16].try_into().unwrap()) as usize;
326 assert_eq!(file_dim, dim);
327
328 let data_start = 16;
330 let total_floats = n * dim;
331 let mut flat_data = Vec::with_capacity(total_floats);
332 let byte_slice = &mmap[data_start..data_start + total_floats * 4];
333 for chunk in byte_slice.chunks_exact(4) {
335 flat_data.push(f32::from_le_bytes(chunk.try_into().unwrap()));
336 }
337 let vectors = FlatVectors {
338 data: flat_data,
339 dim,
340 count: n,
341 };
342
343 let ids_json = fs::read_to_string(dir.join("ids.json"))?;
345 let id_map: Vec<String> = serde_json::from_str(&ids_json)
346 .map_err(|e| DiskAnnError::Serialization(e.to_string()))?;
347
348 let mut id_reverse = HashMap::new();
349 for (i, id) in id_map.iter().enumerate() {
350 id_reverse.insert(id.clone(), i as u32);
351 }
352
353 let graph_bytes = fs::read(dir.join("graph.bin"))?;
355 let medoid = u64::from_le_bytes(graph_bytes[0..8].try_into().unwrap()) as u32;
356 let graph_n = u64::from_le_bytes(graph_bytes[8..16].try_into().unwrap()) as usize;
357
358 let mut neighbors = Vec::with_capacity(graph_n);
359 let mut offset = 16;
360 for _ in 0..graph_n {
361 let deg = u32::from_le_bytes(graph_bytes[offset..offset + 4].try_into().unwrap()) as usize;
362 offset += 4;
363 let mut nbrs = Vec::with_capacity(deg);
364 for _ in 0..deg {
365 let nbr = u32::from_le_bytes(graph_bytes[offset..offset + 4].try_into().unwrap());
366 offset += 4;
367 nbrs.push(nbr);
368 }
369 neighbors.push(nbrs);
370 }
371
372 let graph = VamanaGraph {
373 neighbors,
374 medoid,
375 max_degree,
376 build_beam,
377 alpha,
378 };
379
380 let pq_path = dir.join("pq.bin");
382 let (pq, pq_codes) = if pq_path.exists() {
383 let pq_bytes = fs::read(&pq_path)?;
384 let (pq, _): (ProductQuantizer, usize) =
385 bincode::decode_from_slice(&pq_bytes, bincode::config::standard())
386 .map_err(|e| DiskAnnError::Serialization(e.to_string()))?;
387
388 let codes_bytes = fs::read(dir.join("pq_codes.bin"))?;
389 let m = pq.m;
390 let mut codes = Vec::with_capacity(n);
391 for i in 0..n {
392 codes.push(codes_bytes[i * m..(i + 1) * m].to_vec());
393 }
394 (Some(pq), codes)
395 } else {
396 (None, Vec::new())
397 };
398
399 Ok(Self {
400 config,
401 vectors,
402 id_map,
403 id_reverse,
404 graph: Some(graph),
405 pq,
406 pq_codes,
407 built: true,
408 visited: Some(VisitedSet::new(n)),
409 mmap: Some(mmap),
410 })
411 }
412}
413
414#[cfg(test)]
415mod tests {
416 use super::*;
417 use tempfile::tempdir;
418
419 fn random_vectors(n: usize, dim: usize) -> Vec<(String, Vec<f32>)> {
420 use rand::prelude::*;
421 let mut rng = rand::thread_rng();
422 (0..n)
423 .map(|i| {
424 let v: Vec<f32> = (0..dim).map(|_| rng.gen()).collect();
425 (format!("vec-{i}"), v)
426 })
427 .collect()
428 }
429
430 fn random_data(n: usize, dim: usize) -> Vec<(String, Vec<f32>)> {
431 random_vectors(n, dim)
432 }
433
434 #[test]
435 fn test_diskann_basic() {
436 let mut index = DiskAnnIndex::new(DiskAnnConfig {
437 dim: 32,
438 max_degree: 16,
439 build_beam: 32,
440 search_beam: 32,
441 alpha: 1.2,
442 ..Default::default()
443 });
444
445 let data = random_vectors(500, 32);
446 let query = data[42].1.clone();
447
448 index.insert_batch(data).unwrap();
449 index.build().unwrap();
450
451 let results = index.search(&query, 5).unwrap();
452 assert!(!results.is_empty());
453 assert_eq!(results[0].id, "vec-42"); assert!(results[0].distance < 1e-6); }
456
457 #[test]
458 fn test_diskann_with_pq() {
459 let mut index = DiskAnnIndex::new(DiskAnnConfig {
460 dim: 32,
461 max_degree: 16,
462 build_beam: 32,
463 search_beam: 32,
464 alpha: 1.2,
465 pq_subspaces: 4,
466 pq_iterations: 5,
467 ..Default::default()
468 });
469
470 let data = random_vectors(200, 32);
471 let query = data[10].1.clone();
472
473 index.insert_batch(data).unwrap();
474 index.build().unwrap();
475
476 let results = index.search(&query, 5).unwrap();
477 assert_eq!(results[0].id, "vec-10");
478 }
479
480 #[test]
481 fn test_diskann_save_load() {
482 let dir = tempdir().unwrap();
483 let path = dir.path().join("diskann_test");
484
485 let data = random_vectors(100, 16);
486 let query = data[7].1.clone();
487
488 {
490 let mut index = DiskAnnIndex::new(DiskAnnConfig {
491 dim: 16,
492 max_degree: 8,
493 build_beam: 16,
494 search_beam: 16,
495 alpha: 1.2,
496 storage_path: Some(path.clone()),
497 ..Default::default()
498 });
499 index.insert_batch(data).unwrap();
500 index.build().unwrap();
501 }
502
503 let loaded = DiskAnnIndex::load(&path).unwrap();
505 let results = loaded.search(&query, 3).unwrap();
506 assert_eq!(results[0].id, "vec-7");
507 }
508
509 #[test]
510 fn test_recall_at_10() {
511 use rand::prelude::*;
513 let mut rng = rand::thread_rng();
514 let n = 2000;
515 let dim = 64;
516 let k = 10;
517
518 let data: Vec<(String, Vec<f32>)> = (0..n)
519 .map(|i| {
520 let v: Vec<f32> = (0..dim).map(|_| rng.gen()).collect();
521 (format!("v{i}"), v)
522 })
523 .collect();
524
525 let mut index = DiskAnnIndex::new(DiskAnnConfig {
526 dim,
527 max_degree: 32,
528 build_beam: 64,
529 search_beam: 64,
530 alpha: 1.2,
531 ..Default::default()
532 });
533 index.insert_batch(data.clone()).unwrap();
534 index.build().unwrap();
535
536 let num_queries = 50;
538 let mut total_recall = 0.0;
539
540 for _ in 0..num_queries {
541 let qi = rng.gen_range(0..n);
542 let query = &data[qi].1;
543
544 let mut brute: Vec<(usize, f32)> = data
546 .iter()
547 .enumerate()
548 .map(|(i, (_, v))| (i, crate::distance::l2_squared(v, query)))
549 .collect();
550 brute.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
551 let gt: std::collections::HashSet<String> = brute[..k]
552 .iter()
553 .map(|(i, _)| data[*i].0.clone())
554 .collect();
555
556 let results = index.search(query, k).unwrap();
558 let found: std::collections::HashSet<String> =
559 results.iter().map(|r| r.id.clone()).collect();
560
561 let recall = gt.intersection(&found).count() as f64 / k as f64;
562 total_recall += recall;
563 }
564
565 let avg_recall = total_recall / num_queries as f64;
566 println!("Recall@{k} = {avg_recall:.3} (n={n}, dim={dim}, queries={num_queries})");
567 assert!(
568 avg_recall >= 0.85,
569 "Recall@{k} = {avg_recall:.3}, expected >= 0.85"
570 );
571 }
572
573 #[test]
574 fn test_dimension_mismatch() {
575 let mut index = DiskAnnIndex::new(DiskAnnConfig {
576 dim: 16,
577 ..Default::default()
578 });
579
580 let result = index.insert("bad".to_string(), vec![1.0; 32]);
582 assert!(result.is_err());
583
584 index.insert("ok".to_string(), vec![1.0; 16]).unwrap();
586 index.build().unwrap();
587 let result = index.search(&[1.0; 32], 1);
588 assert!(result.is_err());
589 }
590
591 #[test]
592 fn test_duplicate_id_rejected() {
593 let mut index = DiskAnnIndex::new(DiskAnnConfig {
594 dim: 4,
595 ..Default::default()
596 });
597 index.insert("a".to_string(), vec![1.0; 4]).unwrap();
598 let result = index.insert("a".to_string(), vec![2.0; 4]);
599 assert!(result.is_err());
600 }
601
602 #[test]
603 fn test_search_before_build_fails() {
604 let mut index = DiskAnnIndex::new(DiskAnnConfig {
605 dim: 4,
606 ..Default::default()
607 });
608 index.insert("a".to_string(), vec![1.0; 4]).unwrap();
609 let result = index.search(&[1.0; 4], 1);
610 assert!(result.is_err());
611 }
612
613 #[test]
614 fn test_scale_5k() {
615 use std::time::Instant;
617 use rand::prelude::*;
618 let mut rng = rand::thread_rng();
619
620 let n = 5000;
621 let dim = 128;
622 let data: Vec<(String, Vec<f32>)> = (0..n)
623 .map(|i| {
624 let v: Vec<f32> = (0..dim).map(|_| rng.gen()).collect();
625 (format!("v{i}"), v)
626 })
627 .collect();
628
629 let mut index = DiskAnnIndex::new(DiskAnnConfig {
630 dim,
631 max_degree: 48,
632 build_beam: 96,
633 search_beam: 48,
634 alpha: 1.2,
635 ..Default::default()
636 });
637 index.insert_batch(data.clone()).unwrap();
638
639 let t0 = Instant::now();
640 index.build().unwrap();
641 let build_ms = t0.elapsed().as_millis();
642 println!("Build {n} vectors ({dim}d): {build_ms}ms");
643
644 let query = &data[0].1;
646 let t0 = Instant::now();
647 let iters = 100;
648 for _ in 0..iters {
649 let _ = index.search(query, 10).unwrap();
650 }
651 let search_us = t0.elapsed().as_micros() / iters;
652 println!("Search latency (k=10): {search_us}µs avg over {iters} queries");
653
654 assert!(search_us < 10_000, "Search took {search_us}µs, expected <10ms");
655 }
656}