1use crate::distance::{l2_squared, FlatVectors, VisitedSet};
4use crate::error::{DiskAnnError, Result};
5use crate::graph::VamanaGraph;
6use crate::pq::ProductQuantizer;
7use memmap2::{Mmap, MmapOptions};
8use std::collections::HashMap;
9use std::fs::{self, File};
10use std::io::{BufWriter, Write};
11use std::path::{Path, PathBuf};
12
/// One hit returned by [`DiskAnnIndex::search`].
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// The external id the vector was inserted under.
    pub id: String,
    /// Squared Euclidean (L2²) distance between the query and this vector,
    /// as computed by `search`.
    pub distance: f32,
}
19
/// Construction and search parameters for a [`DiskAnnIndex`].
#[derive(Debug, Clone)]
pub struct DiskAnnConfig {
    /// Dimensionality every inserted vector and every query must have.
    pub dim: usize,
    /// Maximum out-degree passed to the Vamana graph builder.
    pub max_degree: usize,
    /// Beam width used while building the graph.
    pub build_beam: usize,
    /// Beam width used at query time; `search` widens it to at least `k`.
    pub search_beam: usize,
    /// Alpha parameter handed to the Vamana graph builder.
    pub alpha: f32,
    /// Number of product-quantization subspaces; `0` disables PQ.
    pub pq_subspaces: usize,
    /// Training iterations for the product quantizer.
    pub pq_iterations: usize,
    /// If set, `build` saves the index to this directory after building.
    pub storage_path: Option<PathBuf>,
}

impl Default for DiskAnnConfig {
    /// 128-dimensional vectors, degree 64, build/search beams 128/64,
    /// alpha 1.2, PQ disabled (10 training iterations if enabled), no persistence.
    fn default() -> Self {
        DiskAnnConfig {
            storage_path: None,
            pq_iterations: 10,
            pq_subspaces: 0,
            alpha: 1.2,
            search_beam: 64,
            build_beam: 128,
            max_degree: 64,
            dim: 128,
        }
    }
}
55
56pub struct DiskAnnIndex {
58 config: DiskAnnConfig,
59 vectors: FlatVectors,
61 id_map: Vec<String>,
63 id_reverse: HashMap<String, u32>,
65 graph: Option<VamanaGraph>,
67 pq: Option<ProductQuantizer>,
69 pq_codes: Vec<Vec<u8>>,
71 built: bool,
73 visited: Option<VisitedSet>,
75 mmap: Option<Mmap>,
77}
78
79impl DiskAnnIndex {
80 pub fn new(config: DiskAnnConfig) -> Self {
82 let dim = config.dim;
83 Self {
84 config,
85 vectors: FlatVectors::new(dim),
86 id_map: Vec::new(),
87 id_reverse: HashMap::new(),
88 graph: None,
89 pq: None,
90 pq_codes: Vec::new(),
91 built: false,
92 visited: None,
93 mmap: None,
94 }
95 }
96
97 pub fn insert(&mut self, id: String, vector: Vec<f32>) -> Result<()> {
99 if vector.len() != self.config.dim {
100 return Err(DiskAnnError::DimensionMismatch {
101 expected: self.config.dim,
102 actual: vector.len(),
103 });
104 }
105 if self.id_reverse.contains_key(&id) {
106 return Err(DiskAnnError::InvalidConfig(format!("Duplicate ID: {id}")));
107 }
108
109 let idx = self.vectors.len() as u32;
110 self.id_reverse.insert(id.clone(), idx);
111 self.id_map.push(id);
112 self.vectors.push(&vector);
113 self.built = false;
114 Ok(())
115 }
116
117 pub fn insert_batch(&mut self, entries: Vec<(String, Vec<f32>)>) -> Result<()> {
119 for (id, vector) in entries {
120 self.insert(id, vector)?;
121 }
122 Ok(())
123 }
124
125 pub fn build(&mut self) -> Result<()> {
127 let n = self.vectors.len();
128 if n == 0 {
129 return Err(DiskAnnError::Empty);
130 }
131
132 if self.config.pq_subspaces > 0 {
134 let vecs: Vec<Vec<f32>> = (0..n).map(|i| self.vectors.get(i).to_vec()).collect();
136 let mut pq = ProductQuantizer::new(self.config.dim, self.config.pq_subspaces)?;
137 pq.train(&vecs, self.config.pq_iterations)?;
138
139 self.pq_codes = vecs
140 .iter()
141 .map(|v| pq.encode(v))
142 .collect::<Result<Vec<_>>>()?;
143
144 self.pq = Some(pq);
145 }
146
147 let mut graph = VamanaGraph::new(
149 n,
150 self.config.max_degree,
151 self.config.build_beam,
152 self.config.alpha,
153 );
154 graph.build(&self.vectors)?;
155 self.graph = Some(graph);
156
157 self.visited = Some(VisitedSet::new(n));
159 self.built = true;
160
161 if let Some(ref path) = self.config.storage_path {
162 self.save(path)?;
163 }
164
165 Ok(())
166 }
167
168 pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
170 if !self.built {
171 return Err(DiskAnnError::NotBuilt);
172 }
173 if query.len() != self.config.dim {
174 return Err(DiskAnnError::DimensionMismatch {
175 expected: self.config.dim,
176 actual: query.len(),
177 });
178 }
179
180 let graph = self.graph.as_ref().unwrap();
181 let beam = self.config.search_beam.max(k);
182
183 let (candidates, _) = graph.greedy_search(&self.vectors, query, beam);
184
185 let mut scored: Vec<(u32, f32)> = candidates
187 .into_iter()
188 .map(|id| (id, l2_squared(self.vectors.get(id as usize), query)))
189 .collect();
190 scored.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
191
192 Ok(scored
193 .into_iter()
194 .take(k)
195 .map(|(id, dist)| SearchResult {
196 id: self.id_map[id as usize].clone(),
197 distance: dist,
198 })
199 .collect())
200 }
201
202 pub fn count(&self) -> usize {
204 self.vectors.len()
205 }
206
207 pub fn delete(&mut self, id: &str) -> Result<bool> {
209 if let Some(&idx) = self.id_reverse.get(id) {
210 self.vectors.zero_out(idx as usize);
211 self.id_reverse.remove(id);
212 Ok(true)
213 } else {
214 Ok(false)
215 }
216 }
217
218 pub fn save(&self, dir: &Path) -> Result<()> {
220 fs::create_dir_all(dir)?;
221
222 let vec_path = dir.join("vectors.bin");
224 let mut f = BufWriter::new(File::create(&vec_path)?);
225 let n = self.vectors.len() as u64;
226 let dim = self.config.dim as u64;
227 f.write_all(&n.to_le_bytes())?;
228 f.write_all(&dim.to_le_bytes())?;
229 let byte_slice = unsafe {
231 std::slice::from_raw_parts(
232 self.vectors.data.as_ptr() as *const u8,
233 self.vectors.data.len() * 4,
234 )
235 };
236 f.write_all(byte_slice)?;
237 f.flush()?;
238
239 let graph_path = dir.join("graph.bin");
241 let mut f = BufWriter::new(File::create(&graph_path)?);
242 if let Some(ref graph) = self.graph {
243 f.write_all(&(graph.medoid as u64).to_le_bytes())?;
244 f.write_all(&(graph.neighbors.len() as u64).to_le_bytes())?;
245 for neighbors in &graph.neighbors {
246 f.write_all(&(neighbors.len() as u32).to_le_bytes())?;
247 for &n in neighbors {
248 f.write_all(&n.to_le_bytes())?;
249 }
250 }
251 }
252 f.flush()?;
253
254 let ids_path = dir.join("ids.json");
256 let ids_json = serde_json::to_string(&self.id_map)
257 .map_err(|e| DiskAnnError::Serialization(e.to_string()))?;
258 fs::write(&ids_path, ids_json)?;
259
260 if let Some(ref pq) = self.pq {
262 let pq_path = dir.join("pq.bin");
263 let pq_bytes = bincode::encode_to_vec(pq, bincode::config::standard())
264 .map_err(|e| DiskAnnError::Serialization(e.to_string()))?;
265 fs::write(&pq_path, pq_bytes)?;
266
267 let codes_path = dir.join("pq_codes.bin");
269 let mut f = BufWriter::new(File::create(&codes_path)?);
270 for codes in &self.pq_codes {
271 f.write_all(codes)?;
272 }
273 f.flush()?;
274 }
275
276 let config_path = dir.join("config.json");
278 let config_json = serde_json::json!({
279 "dim": self.config.dim,
280 "max_degree": self.config.max_degree,
281 "build_beam": self.config.build_beam,
282 "search_beam": self.config.search_beam,
283 "alpha": self.config.alpha,
284 "pq_subspaces": self.config.pq_subspaces,
285 "count": self.vectors.len(),
286 "built": self.built,
287 });
288 fs::write(
289 &config_path,
290 serde_json::to_string_pretty(&config_json).unwrap(),
291 )?;
292
293 Ok(())
294 }
295
296 pub fn load(dir: &Path) -> Result<Self> {
298 let config_json: serde_json::Value =
300 serde_json::from_str(&fs::read_to_string(dir.join("config.json"))?)
301 .map_err(|e| DiskAnnError::Serialization(e.to_string()))?;
302
303 let dim = config_json["dim"].as_u64().unwrap() as usize;
304 let max_degree = config_json["max_degree"].as_u64().unwrap() as usize;
305 let build_beam = config_json["build_beam"].as_u64().unwrap() as usize;
306 let search_beam = config_json["search_beam"].as_u64().unwrap() as usize;
307 let alpha = config_json["alpha"].as_f64().unwrap() as f32;
308 let pq_subspaces = config_json["pq_subspaces"].as_u64().unwrap_or(0) as usize;
309
310 let config = DiskAnnConfig {
311 dim,
312 max_degree,
313 build_beam,
314 search_beam,
315 alpha,
316 pq_subspaces,
317 storage_path: Some(dir.to_path_buf()),
318 ..Default::default()
319 };
320
321 let vec_file = File::open(dir.join("vectors.bin"))?;
323 let mmap = unsafe { MmapOptions::new().map(&vec_file)? };
324
325 let n = u64::from_le_bytes(mmap[0..8].try_into().unwrap()) as usize;
326 let file_dim = u64::from_le_bytes(mmap[8..16].try_into().unwrap()) as usize;
327 assert_eq!(file_dim, dim);
328
329 let data_start = 16;
331 let total_floats = n * dim;
332 let mut flat_data = Vec::with_capacity(total_floats);
333 let byte_slice = &mmap[data_start..data_start + total_floats * 4];
334 for chunk in byte_slice.chunks_exact(4) {
336 flat_data.push(f32::from_le_bytes(chunk.try_into().unwrap()));
337 }
338 let vectors = FlatVectors {
339 data: flat_data,
340 dim,
341 count: n,
342 };
343
344 let ids_json = fs::read_to_string(dir.join("ids.json"))?;
346 let id_map: Vec<String> = serde_json::from_str(&ids_json)
347 .map_err(|e| DiskAnnError::Serialization(e.to_string()))?;
348
349 let mut id_reverse = HashMap::new();
350 for (i, id) in id_map.iter().enumerate() {
351 id_reverse.insert(id.clone(), i as u32);
352 }
353
354 let graph_bytes = fs::read(dir.join("graph.bin"))?;
356 let medoid = u64::from_le_bytes(graph_bytes[0..8].try_into().unwrap()) as u32;
357 let graph_n = u64::from_le_bytes(graph_bytes[8..16].try_into().unwrap()) as usize;
358
359 let mut neighbors = Vec::with_capacity(graph_n);
360 let mut offset = 16;
361 for _ in 0..graph_n {
362 let deg =
363 u32::from_le_bytes(graph_bytes[offset..offset + 4].try_into().unwrap()) as usize;
364 offset += 4;
365 let mut nbrs = Vec::with_capacity(deg);
366 for _ in 0..deg {
367 let nbr = u32::from_le_bytes(graph_bytes[offset..offset + 4].try_into().unwrap());
368 offset += 4;
369 nbrs.push(nbr);
370 }
371 neighbors.push(nbrs);
372 }
373
374 let graph = VamanaGraph {
375 neighbors,
376 medoid,
377 max_degree,
378 build_beam,
379 alpha,
380 };
381
382 let pq_path = dir.join("pq.bin");
384 let (pq, pq_codes) = if pq_path.exists() {
385 let pq_bytes = fs::read(&pq_path)?;
386 let (pq, _): (ProductQuantizer, usize) =
387 bincode::decode_from_slice(&pq_bytes, bincode::config::standard())
388 .map_err(|e| DiskAnnError::Serialization(e.to_string()))?;
389
390 let codes_bytes = fs::read(dir.join("pq_codes.bin"))?;
391 let m = pq.m;
392 let mut codes = Vec::with_capacity(n);
393 for i in 0..n {
394 codes.push(codes_bytes[i * m..(i + 1) * m].to_vec());
395 }
396 (Some(pq), codes)
397 } else {
398 (None, Vec::new())
399 };
400
401 Ok(Self {
402 config,
403 vectors,
404 id_map,
405 id_reverse,
406 graph: Some(graph),
407 pq,
408 pq_codes,
409 built: true,
410 visited: Some(VisitedSet::new(n)),
411 mmap: Some(mmap),
412 })
413 }
414}
415
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Deterministic `(id, vector)` pairs with ids `vec-0`, `vec-1`, ...
    fn random_vectors(n: usize, dim: usize) -> Vec<(String, Vec<f32>)> {
        use rand::prelude::*;
        let mut rng = rand::rngs::StdRng::seed_from_u64(0xD15CA77);
        (0..n)
            .map(|i| {
                let v: Vec<f32> = (0..dim).map(|_| rng.gen()).collect();
                (format!("vec-{i}"), v)
            })
            .collect()
    }

    #[test]
    fn test_diskann_basic() {
        let mut index = DiskAnnIndex::new(DiskAnnConfig {
            dim: 32,
            max_degree: 16,
            build_beam: 32,
            search_beam: 32,
            alpha: 1.2,
            ..Default::default()
        });

        let data = random_vectors(500, 32);
        let query = data[42].1.clone();

        index.insert_batch(data).unwrap();
        index.build().unwrap();

        let results = index.search(&query, 5).unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].id, "vec-42");
        assert!(results[0].distance < 1e-6);
    }

    #[test]
    fn test_diskann_with_pq() {
        let mut index = DiskAnnIndex::new(DiskAnnConfig {
            dim: 32,
            max_degree: 16,
            build_beam: 32,
            search_beam: 32,
            alpha: 1.2,
            pq_subspaces: 4,
            pq_iterations: 5,
            ..Default::default()
        });

        let data = random_vectors(200, 32);
        let query = data[10].1.clone();

        index.insert_batch(data).unwrap();
        index.build().unwrap();

        let results = index.search(&query, 5).unwrap();
        assert_eq!(results[0].id, "vec-10");
    }

    #[test]
    fn test_diskann_save_load() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("diskann_test");

        let data = random_vectors(100, 16);
        let query = data[7].1.clone();

        {
            // `build` persists to `storage_path`; the index is dropped before reloading.
            let mut index = DiskAnnIndex::new(DiskAnnConfig {
                dim: 16,
                max_degree: 8,
                build_beam: 16,
                search_beam: 16,
                alpha: 1.2,
                storage_path: Some(path.clone()),
                ..Default::default()
            });
            index.insert_batch(data).unwrap();
            index.build().unwrap();
        }

        let loaded = DiskAnnIndex::load(&path).unwrap();
        let results = loaded.search(&query, 3).unwrap();
        assert_eq!(results[0].id, "vec-7");
    }

    #[test]
    fn test_recall_at_10() {
        use rand::prelude::*;
        let mut rng = rand::rngs::StdRng::seed_from_u64(0xD15CA77);
        let n = 2000;
        let dim = 64;
        let k = 10;

        let data: Vec<(String, Vec<f32>)> = (0..n)
            .map(|i| {
                let v: Vec<f32> = (0..dim).map(|_| rng.gen()).collect();
                (format!("v{i}"), v)
            })
            .collect();

        let mut index = DiskAnnIndex::new(DiskAnnConfig {
            dim,
            max_degree: 32,
            build_beam: 64,
            search_beam: 64,
            alpha: 1.2,
            ..Default::default()
        });
        index.insert_batch(data.clone()).unwrap();
        index.build().unwrap();

        let num_queries = 50;
        let mut total_recall = 0.0;

        for _ in 0..num_queries {
            let qi = rng.gen_range(0..n);
            let query = &data[qi].1;

            // Exact top-k by brute force is the ground truth.
            let mut brute: Vec<(usize, f32)> = data
                .iter()
                .enumerate()
                .map(|(i, (_, v))| (i, crate::distance::l2_squared(v, query)))
                .collect();
            brute.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
            let gt: std::collections::HashSet<String> =
                brute[..k].iter().map(|(i, _)| data[*i].0.clone()).collect();

            let results = index.search(query, k).unwrap();
            let found: std::collections::HashSet<String> =
                results.iter().map(|r| r.id.clone()).collect();

            let recall = gt.intersection(&found).count() as f64 / k as f64;
            total_recall += recall;
        }

        let avg_recall = total_recall / num_queries as f64;
        println!("Recall@{k} = {avg_recall:.3} (n={n}, dim={dim}, queries={num_queries})");
        assert!(
            avg_recall >= 0.85,
            "Recall@{k} = {avg_recall:.3}, expected >= 0.85"
        );
    }

    #[test]
    fn test_dimension_mismatch() {
        let mut index = DiskAnnIndex::new(DiskAnnConfig {
            dim: 16,
            ..Default::default()
        });

        let result = index.insert("bad".to_string(), vec![1.0; 32]);
        assert!(result.is_err());

        index.insert("ok".to_string(), vec![1.0; 16]).unwrap();
        index.build().unwrap();
        let result = index.search(&[1.0; 32], 1);
        assert!(result.is_err());
    }

    #[test]
    fn test_duplicate_id_rejected() {
        let mut index = DiskAnnIndex::new(DiskAnnConfig {
            dim: 4,
            ..Default::default()
        });
        index.insert("a".to_string(), vec![1.0; 4]).unwrap();
        let result = index.insert("a".to_string(), vec![2.0; 4]);
        assert!(result.is_err());
    }

    #[test]
    fn test_search_before_build_fails() {
        let mut index = DiskAnnIndex::new(DiskAnnConfig {
            dim: 4,
            ..Default::default()
        });
        index.insert("a".to_string(), vec![1.0; 4]).unwrap();
        let result = index.search(&[1.0; 4], 1);
        assert!(result.is_err());
    }

    #[test]
    fn test_build_empty_fails() {
        let mut index = DiskAnnIndex::new(DiskAnnConfig {
            dim: 4,
            ..Default::default()
        });
        assert!(index.build().is_err());
    }

    #[test]
    fn test_delete_reports_presence() {
        let mut index = DiskAnnIndex::new(DiskAnnConfig {
            dim: 4,
            ..Default::default()
        });
        index.insert("a".to_string(), vec![1.0; 4]).unwrap();
        assert!(index.delete("a").unwrap());
        assert!(!index.delete("a").unwrap());
        assert!(!index.delete("missing").unwrap());
        // A deleted id is released and may be inserted again.
        index.insert("a".to_string(), vec![2.0; 4]).unwrap();
    }

    // The wall-clock bound is only meaningful for optimized builds; in debug
    // builds this test is slow and its timing assertion flaky.
    #[test]
    #[cfg_attr(
        debug_assertions,
        ignore = "timing assertion is only meaningful in optimized builds"
    )]
    fn test_scale_5k() {
        use rand::prelude::*;
        use std::time::Instant;
        let mut rng = rand::rngs::StdRng::seed_from_u64(0xD15CA77);

        let n = 5000;
        let dim = 128;
        let data: Vec<(String, Vec<f32>)> = (0..n)
            .map(|i| {
                let v: Vec<f32> = (0..dim).map(|_| rng.gen()).collect();
                (format!("v{i}"), v)
            })
            .collect();

        let mut index = DiskAnnIndex::new(DiskAnnConfig {
            dim,
            max_degree: 48,
            build_beam: 96,
            search_beam: 48,
            alpha: 1.2,
            ..Default::default()
        });
        index.insert_batch(data.clone()).unwrap();

        let t0 = Instant::now();
        index.build().unwrap();
        let build_ms = t0.elapsed().as_millis();
        println!("Build {n} vectors ({dim}d): {build_ms}ms");

        let query = &data[0].1;
        let t0 = Instant::now();
        let iters = 100;
        for _ in 0..iters {
            let _ = index.search(query, 10).unwrap();
        }
        let search_us = t0.elapsed().as_micros() / iters;
        println!("Search latency (k=10): {search_us}µs avg over {iters} queries");

        assert!(
            search_us < 10_000,
            "Search took {search_us}µs, expected <10ms"
        );
    }
}