Skip to main content

ruvector_diskann/
index.rs

1//! DiskANN index — ties together Vamana graph, PQ, and mmap persistence
2
3use crate::distance::{l2_squared, FlatVectors, VisitedSet};
4use crate::error::{DiskAnnError, Result};
5use crate::graph::VamanaGraph;
6use crate::pq::ProductQuantizer;
7use memmap2::{Mmap, MmapOptions};
8use std::collections::HashMap;
9use std::fs::{self, File};
10use std::io::{BufWriter, Write};
11use std::path::{Path, PathBuf};
12
13/// Search result
14#[derive(Debug, Clone)]
15pub struct SearchResult {
16    pub id: String,
17    pub distance: f32,
18}
19
20/// Configuration for DiskANN index
21#[derive(Debug, Clone)]
22pub struct DiskAnnConfig {
23    /// Vector dimension
24    pub dim: usize,
25    /// Maximum out-degree for Vamana graph (R)
26    pub max_degree: usize,
27    /// Search beam width during construction (L_build)
28    pub build_beam: usize,
29    /// Search beam width during query (L_search)
30    pub search_beam: usize,
31    /// Alpha parameter for robust pruning (>= 1.0)
32    pub alpha: f32,
33    /// Number of PQ subspaces (M). 0 = no PQ.
34    pub pq_subspaces: usize,
35    /// PQ training iterations
36    pub pq_iterations: usize,
37    /// Storage directory for persistence
38    pub storage_path: Option<PathBuf>,
39}
40
41impl Default for DiskAnnConfig {
42    fn default() -> Self {
43        Self {
44            dim: 128,
45            max_degree: 64,
46            build_beam: 128,
47            search_beam: 64,
48            alpha: 1.2,
49            pq_subspaces: 0,
50            pq_iterations: 10,
51            storage_path: None,
52        }
53    }
54}
55
56/// DiskANN index with Vamana graph + optional PQ + mmap persistence
57pub struct DiskAnnIndex {
58    config: DiskAnnConfig,
59    /// Flat contiguous vector storage (cache-friendly)
60    vectors: FlatVectors,
61    /// ID mapping: internal index -> external string ID
62    id_map: Vec<String>,
63    /// Reverse mapping: external ID -> internal index
64    id_reverse: HashMap<String, u32>,
65    /// Vamana graph
66    graph: Option<VamanaGraph>,
67    /// Product quantizer (optional)
68    pq: Option<ProductQuantizer>,
69    /// PQ codes for all vectors
70    pq_codes: Vec<Vec<u8>>,
71    /// Whether index has been built
72    built: bool,
73    /// Reusable visited set for search (avoids per-query allocation)
74    visited: Option<VisitedSet>,
75    /// Memory-mapped vector data (for large datasets)
76    mmap: Option<Mmap>,
77}
78
79impl DiskAnnIndex {
80    /// Create a new DiskANN index
81    pub fn new(config: DiskAnnConfig) -> Self {
82        let dim = config.dim;
83        Self {
84            config,
85            vectors: FlatVectors::new(dim),
86            id_map: Vec::new(),
87            id_reverse: HashMap::new(),
88            graph: None,
89            pq: None,
90            pq_codes: Vec::new(),
91            built: false,
92            visited: None,
93            mmap: None,
94        }
95    }
96
97    /// Insert a vector with a string ID
98    pub fn insert(&mut self, id: String, vector: Vec<f32>) -> Result<()> {
99        if vector.len() != self.config.dim {
100            return Err(DiskAnnError::DimensionMismatch {
101                expected: self.config.dim,
102                actual: vector.len(),
103            });
104        }
105        if self.id_reverse.contains_key(&id) {
106            return Err(DiskAnnError::InvalidConfig(format!("Duplicate ID: {id}")));
107        }
108
109        let idx = self.vectors.len() as u32;
110        self.id_reverse.insert(id.clone(), idx);
111        self.id_map.push(id);
112        self.vectors.push(&vector);
113        self.built = false;
114        Ok(())
115    }
116
117    /// Insert a batch of vectors
118    pub fn insert_batch(&mut self, entries: Vec<(String, Vec<f32>)>) -> Result<()> {
119        for (id, vector) in entries {
120            self.insert(id, vector)?;
121        }
122        Ok(())
123    }
124
125    /// Build the index (must be called after all inserts, before search)
126    pub fn build(&mut self) -> Result<()> {
127        let n = self.vectors.len();
128        if n == 0 {
129            return Err(DiskAnnError::Empty);
130        }
131
132        // Train PQ if configured
133        if self.config.pq_subspaces > 0 {
134            // Collect vectors for PQ training
135            let vecs: Vec<Vec<f32>> = (0..n).map(|i| self.vectors.get(i).to_vec()).collect();
136            let mut pq = ProductQuantizer::new(self.config.dim, self.config.pq_subspaces)?;
137            pq.train(&vecs, self.config.pq_iterations)?;
138
139            self.pq_codes = vecs
140                .iter()
141                .map(|v| pq.encode(v))
142                .collect::<Result<Vec<_>>>()?;
143
144            self.pq = Some(pq);
145        }
146
147        // Build Vamana graph on flat storage
148        let mut graph = VamanaGraph::new(
149            n,
150            self.config.max_degree,
151            self.config.build_beam,
152            self.config.alpha,
153        );
154        graph.build(&self.vectors)?;
155        self.graph = Some(graph);
156
157        // Pre-allocate visited set for search
158        self.visited = Some(VisitedSet::new(n));
159        self.built = true;
160
161        if let Some(ref path) = self.config.storage_path {
162            self.save(path)?;
163        }
164
165        Ok(())
166    }
167
168    /// Search for k nearest neighbors
169    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
170        if !self.built {
171            return Err(DiskAnnError::NotBuilt);
172        }
173        if query.len() != self.config.dim {
174            return Err(DiskAnnError::DimensionMismatch {
175                expected: self.config.dim,
176                actual: query.len(),
177            });
178        }
179
180        let graph = self.graph.as_ref().unwrap();
181        let beam = self.config.search_beam.max(k);
182
183        let (candidates, _) = graph.greedy_search(&self.vectors, query, beam);
184
185        // Re-rank candidates with exact distance
186        let mut scored: Vec<(u32, f32)> = candidates
187            .into_iter()
188            .map(|id| (id, l2_squared(self.vectors.get(id as usize), query)))
189            .collect();
190        scored.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
191
192        Ok(scored
193            .into_iter()
194            .take(k)
195            .map(|(id, dist)| SearchResult {
196                id: self.id_map[id as usize].clone(),
197                distance: dist,
198            })
199            .collect())
200    }
201
202    /// Get the number of vectors in the index
203    pub fn count(&self) -> usize {
204        self.vectors.len()
205    }
206
207    /// Delete a vector by ID (marks as deleted, doesn't rebuild graph)
208    pub fn delete(&mut self, id: &str) -> Result<bool> {
209        if let Some(&idx) = self.id_reverse.get(id) {
210            self.vectors.zero_out(idx as usize);
211            self.id_reverse.remove(id);
212            Ok(true)
213        } else {
214            Ok(false)
215        }
216    }
217
218    /// Save index to disk
219    pub fn save(&self, dir: &Path) -> Result<()> {
220        fs::create_dir_all(dir)?;
221
222        // Save vectors as flat binary (already contiguous — mmap-friendly)
223        let vec_path = dir.join("vectors.bin");
224        let mut f = BufWriter::new(File::create(&vec_path)?);
225        let n = self.vectors.len() as u64;
226        let dim = self.config.dim as u64;
227        f.write_all(&n.to_le_bytes())?;
228        f.write_all(&dim.to_le_bytes())?;
229        // Write flat slab directly — zero copy
230        let byte_slice = unsafe {
231            std::slice::from_raw_parts(
232                self.vectors.data.as_ptr() as *const u8,
233                self.vectors.data.len() * 4,
234            )
235        };
236        f.write_all(byte_slice)?;
237        f.flush()?;
238
239        // Save graph adjacency
240        let graph_path = dir.join("graph.bin");
241        let mut f = BufWriter::new(File::create(&graph_path)?);
242        if let Some(ref graph) = self.graph {
243            f.write_all(&(graph.medoid as u64).to_le_bytes())?;
244            f.write_all(&(graph.neighbors.len() as u64).to_le_bytes())?;
245            for neighbors in &graph.neighbors {
246                f.write_all(&(neighbors.len() as u32).to_le_bytes())?;
247                for &n in neighbors {
248                    f.write_all(&n.to_le_bytes())?;
249                }
250            }
251        }
252        f.flush()?;
253
254        // Save ID map
255        let ids_path = dir.join("ids.json");
256        let ids_json = serde_json::to_string(&self.id_map)
257            .map_err(|e| DiskAnnError::Serialization(e.to_string()))?;
258        fs::write(&ids_path, ids_json)?;
259
260        // Save PQ if present
261        if let Some(ref pq) = self.pq {
262            let pq_path = dir.join("pq.bin");
263            let pq_bytes = bincode::encode_to_vec(pq, bincode::config::standard())
264                .map_err(|e| DiskAnnError::Serialization(e.to_string()))?;
265            fs::write(&pq_path, pq_bytes)?;
266
267            // Save PQ codes
268            let codes_path = dir.join("pq_codes.bin");
269            let mut f = BufWriter::new(File::create(&codes_path)?);
270            for codes in &self.pq_codes {
271                f.write_all(codes)?;
272            }
273            f.flush()?;
274        }
275
276        // Save config
277        let config_path = dir.join("config.json");
278        let config_json = serde_json::json!({
279            "dim": self.config.dim,
280            "max_degree": self.config.max_degree,
281            "build_beam": self.config.build_beam,
282            "search_beam": self.config.search_beam,
283            "alpha": self.config.alpha,
284            "pq_subspaces": self.config.pq_subspaces,
285            "count": self.vectors.len(),
286            "built": self.built,
287        });
288        fs::write(
289            &config_path,
290            serde_json::to_string_pretty(&config_json).unwrap(),
291        )?;
292
293        Ok(())
294    }
295
296    /// Load index from disk with memory-mapped vectors
297    pub fn load(dir: &Path) -> Result<Self> {
298        // Load config
299        let config_json: serde_json::Value =
300            serde_json::from_str(&fs::read_to_string(dir.join("config.json"))?)
301                .map_err(|e| DiskAnnError::Serialization(e.to_string()))?;
302
303        let dim = config_json["dim"].as_u64().unwrap() as usize;
304        let max_degree = config_json["max_degree"].as_u64().unwrap() as usize;
305        let build_beam = config_json["build_beam"].as_u64().unwrap() as usize;
306        let search_beam = config_json["search_beam"].as_u64().unwrap() as usize;
307        let alpha = config_json["alpha"].as_f64().unwrap() as f32;
308        let pq_subspaces = config_json["pq_subspaces"].as_u64().unwrap_or(0) as usize;
309
310        let config = DiskAnnConfig {
311            dim,
312            max_degree,
313            build_beam,
314            search_beam,
315            alpha,
316            pq_subspaces,
317            storage_path: Some(dir.to_path_buf()),
318            ..Default::default()
319        };
320
321        // Load vectors via mmap
322        let vec_file = File::open(dir.join("vectors.bin"))?;
323        let mmap = unsafe { MmapOptions::new().map(&vec_file)? };
324
325        let n = u64::from_le_bytes(mmap[0..8].try_into().unwrap()) as usize;
326        let file_dim = u64::from_le_bytes(mmap[8..16].try_into().unwrap()) as usize;
327        assert_eq!(file_dim, dim);
328
329        // Load vectors directly into flat slab from mmap
330        let data_start = 16;
331        let total_floats = n * dim;
332        let mut flat_data = Vec::with_capacity(total_floats);
333        let byte_slice = &mmap[data_start..data_start + total_floats * 4];
334        // Safe: f32 from le bytes
335        for chunk in byte_slice.chunks_exact(4) {
336            flat_data.push(f32::from_le_bytes(chunk.try_into().unwrap()));
337        }
338        let vectors = FlatVectors {
339            data: flat_data,
340            dim,
341            count: n,
342        };
343
344        // Load IDs
345        let ids_json = fs::read_to_string(dir.join("ids.json"))?;
346        let id_map: Vec<String> = serde_json::from_str(&ids_json)
347            .map_err(|e| DiskAnnError::Serialization(e.to_string()))?;
348
349        let mut id_reverse = HashMap::new();
350        for (i, id) in id_map.iter().enumerate() {
351            id_reverse.insert(id.clone(), i as u32);
352        }
353
354        // Load graph
355        let graph_bytes = fs::read(dir.join("graph.bin"))?;
356        let medoid = u64::from_le_bytes(graph_bytes[0..8].try_into().unwrap()) as u32;
357        let graph_n = u64::from_le_bytes(graph_bytes[8..16].try_into().unwrap()) as usize;
358
359        let mut neighbors = Vec::with_capacity(graph_n);
360        let mut offset = 16;
361        for _ in 0..graph_n {
362            let deg =
363                u32::from_le_bytes(graph_bytes[offset..offset + 4].try_into().unwrap()) as usize;
364            offset += 4;
365            let mut nbrs = Vec::with_capacity(deg);
366            for _ in 0..deg {
367                let nbr = u32::from_le_bytes(graph_bytes[offset..offset + 4].try_into().unwrap());
368                offset += 4;
369                nbrs.push(nbr);
370            }
371            neighbors.push(nbrs);
372        }
373
374        let graph = VamanaGraph {
375            neighbors,
376            medoid,
377            max_degree,
378            build_beam,
379            alpha,
380        };
381
382        // Load PQ if present
383        let pq_path = dir.join("pq.bin");
384        let (pq, pq_codes) = if pq_path.exists() {
385            let pq_bytes = fs::read(&pq_path)?;
386            let (pq, _): (ProductQuantizer, usize) =
387                bincode::decode_from_slice(&pq_bytes, bincode::config::standard())
388                    .map_err(|e| DiskAnnError::Serialization(e.to_string()))?;
389
390            let codes_bytes = fs::read(dir.join("pq_codes.bin"))?;
391            let m = pq.m;
392            let mut codes = Vec::with_capacity(n);
393            for i in 0..n {
394                codes.push(codes_bytes[i * m..(i + 1) * m].to_vec());
395            }
396            (Some(pq), codes)
397        } else {
398            (None, Vec::new())
399        };
400
401        Ok(Self {
402            config,
403            vectors,
404            id_map,
405            id_reverse,
406            graph: Some(graph),
407            pq,
408            pq_codes,
409            built: true,
410            visited: Some(VisitedSet::new(n)),
411            mmap: Some(mmap),
412        })
413    }
414}
415
416#[cfg(test)]
417mod tests {
418    use super::*;
419    use tempfile::tempdir;
420
421    fn random_vectors(n: usize, dim: usize) -> Vec<(String, Vec<f32>)> {
422        use rand::prelude::*;
423        // Seeded so tests are deterministic across CI runs — random data made
424        // basic-search assertions (nearest of vec-X is vec-X) flake when the
425        // ANN graph traversal happened to land on an unrelated near-duplicate.
426        let mut rng = rand::rngs::StdRng::seed_from_u64(0xD15CA77);
427        (0..n)
428            .map(|i| {
429                let v: Vec<f32> = (0..dim).map(|_| rng.gen()).collect();
430                (format!("vec-{i}"), v)
431            })
432            .collect()
433    }
434
435    fn random_data(n: usize, dim: usize) -> Vec<(String, Vec<f32>)> {
436        random_vectors(n, dim)
437    }
438
439    #[test]
440    fn test_diskann_basic() {
441        let mut index = DiskAnnIndex::new(DiskAnnConfig {
442            dim: 32,
443            max_degree: 16,
444            build_beam: 32,
445            search_beam: 32,
446            alpha: 1.2,
447            ..Default::default()
448        });
449
450        let data = random_vectors(500, 32);
451        let query = data[42].1.clone();
452
453        index.insert_batch(data).unwrap();
454        index.build().unwrap();
455
456        let results = index.search(&query, 5).unwrap();
457        assert!(!results.is_empty());
458        assert_eq!(results[0].id, "vec-42"); // Should find itself
459        assert!(results[0].distance < 1e-6); // Exact match
460    }
461
462    #[test]
463    fn test_diskann_with_pq() {
464        let mut index = DiskAnnIndex::new(DiskAnnConfig {
465            dim: 32,
466            max_degree: 16,
467            build_beam: 32,
468            search_beam: 32,
469            alpha: 1.2,
470            pq_subspaces: 4,
471            pq_iterations: 5,
472            ..Default::default()
473        });
474
475        let data = random_vectors(200, 32);
476        let query = data[10].1.clone();
477
478        index.insert_batch(data).unwrap();
479        index.build().unwrap();
480
481        let results = index.search(&query, 5).unwrap();
482        assert_eq!(results[0].id, "vec-10");
483    }
484
485    #[test]
486    fn test_diskann_save_load() {
487        let dir = tempdir().unwrap();
488        let path = dir.path().join("diskann_test");
489
490        let data = random_vectors(100, 16);
491        let query = data[7].1.clone();
492
493        // Build and save
494        {
495            let mut index = DiskAnnIndex::new(DiskAnnConfig {
496                dim: 16,
497                max_degree: 8,
498                build_beam: 16,
499                search_beam: 16,
500                alpha: 1.2,
501                storage_path: Some(path.clone()),
502                ..Default::default()
503            });
504            index.insert_batch(data).unwrap();
505            index.build().unwrap();
506        }
507
508        // Load and search
509        let loaded = DiskAnnIndex::load(&path).unwrap();
510        let results = loaded.search(&query, 3).unwrap();
511        assert_eq!(results[0].id, "vec-7");
512    }
513
514    #[test]
515    fn test_recall_at_10() {
516        // Measure recall@10: what fraction of true top-10 neighbors does DiskANN find?
517        use rand::prelude::*;
518        let mut rng = rand::rngs::StdRng::seed_from_u64(0xD15CA77);
519        let n = 2000;
520        let dim = 64;
521        let k = 10;
522
523        let data: Vec<(String, Vec<f32>)> = (0..n)
524            .map(|i| {
525                let v: Vec<f32> = (0..dim).map(|_| rng.gen()).collect();
526                (format!("v{i}"), v)
527            })
528            .collect();
529
530        let mut index = DiskAnnIndex::new(DiskAnnConfig {
531            dim,
532            max_degree: 32,
533            build_beam: 64,
534            search_beam: 64,
535            alpha: 1.2,
536            ..Default::default()
537        });
538        index.insert_batch(data.clone()).unwrap();
539        index.build().unwrap();
540
541        // Test 50 random queries
542        let num_queries = 50;
543        let mut total_recall = 0.0;
544
545        for _ in 0..num_queries {
546            let qi = rng.gen_range(0..n);
547            let query = &data[qi].1;
548
549            // Brute-force ground truth
550            let mut brute: Vec<(usize, f32)> = data
551                .iter()
552                .enumerate()
553                .map(|(i, (_, v))| (i, crate::distance::l2_squared(v, query)))
554                .collect();
555            brute.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
556            let gt: std::collections::HashSet<String> =
557                brute[..k].iter().map(|(i, _)| data[*i].0.clone()).collect();
558
559            // DiskANN search
560            let results = index.search(query, k).unwrap();
561            let found: std::collections::HashSet<String> =
562                results.iter().map(|r| r.id.clone()).collect();
563
564            let recall = gt.intersection(&found).count() as f64 / k as f64;
565            total_recall += recall;
566        }
567
568        let avg_recall = total_recall / num_queries as f64;
569        println!("Recall@{k} = {avg_recall:.3} (n={n}, dim={dim}, queries={num_queries})");
570        assert!(
571            avg_recall >= 0.85,
572            "Recall@{k} = {avg_recall:.3}, expected >= 0.85"
573        );
574    }
575
576    #[test]
577    fn test_dimension_mismatch() {
578        let mut index = DiskAnnIndex::new(DiskAnnConfig {
579            dim: 16,
580            ..Default::default()
581        });
582
583        // Wrong dimension on insert
584        let result = index.insert("bad".to_string(), vec![1.0; 32]);
585        assert!(result.is_err());
586
587        // Wrong dimension on search
588        index.insert("ok".to_string(), vec![1.0; 16]).unwrap();
589        index.build().unwrap();
590        let result = index.search(&[1.0; 32], 1);
591        assert!(result.is_err());
592    }
593
594    #[test]
595    fn test_duplicate_id_rejected() {
596        let mut index = DiskAnnIndex::new(DiskAnnConfig {
597            dim: 4,
598            ..Default::default()
599        });
600        index.insert("a".to_string(), vec![1.0; 4]).unwrap();
601        let result = index.insert("a".to_string(), vec![2.0; 4]);
602        assert!(result.is_err());
603    }
604
605    #[test]
606    fn test_search_before_build_fails() {
607        let mut index = DiskAnnIndex::new(DiskAnnConfig {
608            dim: 4,
609            ..Default::default()
610        });
611        index.insert("a".to_string(), vec![1.0; 4]).unwrap();
612        let result = index.search(&[1.0; 4], 1);
613        assert!(result.is_err());
614    }
615
616    #[test]
617    fn test_scale_5k() {
618        // 5000 vectors, 128-dim — should build in under 5 seconds
619        use rand::prelude::*;
620        use std::time::Instant;
621        let mut rng = rand::rngs::StdRng::seed_from_u64(0xD15CA77);
622
623        let n = 5000;
624        let dim = 128;
625        let data: Vec<(String, Vec<f32>)> = (0..n)
626            .map(|i| {
627                let v: Vec<f32> = (0..dim).map(|_| rng.gen()).collect();
628                (format!("v{i}"), v)
629            })
630            .collect();
631
632        let mut index = DiskAnnIndex::new(DiskAnnConfig {
633            dim,
634            max_degree: 48,
635            build_beam: 96,
636            search_beam: 48,
637            alpha: 1.2,
638            ..Default::default()
639        });
640        index.insert_batch(data.clone()).unwrap();
641
642        let t0 = Instant::now();
643        index.build().unwrap();
644        let build_ms = t0.elapsed().as_millis();
645        println!("Build {n} vectors ({dim}d): {build_ms}ms");
646
647        // Search latency
648        let query = &data[0].1;
649        let t0 = Instant::now();
650        let iters = 100;
651        for _ in 0..iters {
652            let _ = index.search(query, 10).unwrap();
653        }
654        let search_us = t0.elapsed().as_micros() / iters;
655        println!("Search latency (k=10): {search_us}µs avg over {iters} queries");
656
657        assert!(
658            search_us < 10_000,
659            "Search took {search_us}µs, expected <10ms"
660        );
661    }
662}