Skip to main content

ruvector_diskann/
index.rs

1//! DiskANN index — ties together Vamana graph, PQ, and mmap persistence
2
3use crate::distance::{l2_squared, FlatVectors, VisitedSet};
4use crate::error::{DiskAnnError, Result};
5use crate::graph::VamanaGraph;
6use crate::pq::ProductQuantizer;
7use memmap2::{Mmap, MmapOptions};
8use std::collections::HashMap;
9use std::fs::{self, File};
10use std::io::{BufWriter, Write};
11use std::path::{Path, PathBuf};
12
/// A single nearest-neighbour hit returned by `DiskAnnIndex::search`.
///
/// Derives `PartialEq` so that results can be compared directly
/// (for example with `assert_eq!` in tests).
#[derive(Debug, Clone, PartialEq)]
pub struct SearchResult {
    /// External string ID that was supplied when the vector was inserted.
    pub id: String,
    /// Squared L2 distance between the query and the stored vector
    /// (smaller means closer).
    pub distance: f32,
}
19
/// Tunable parameters for a DiskANN index.
///
/// Derives `PartialEq` so two configurations can be compared, e.g. to check
/// that a configuration survives a save/load round trip.
#[derive(Debug, Clone, PartialEq)]
pub struct DiskAnnConfig {
    /// Number of components every inserted vector and every query must have.
    pub dim: usize,
    /// Maximum out-degree for the Vamana graph (R).
    pub max_degree: usize,
    /// Search beam width used while constructing the graph (L_build).
    pub build_beam: usize,
    /// Search beam width used while answering queries (L_search).
    /// A query asking for more than this many results widens the beam to `k`.
    pub search_beam: usize,
    /// Alpha parameter for robust pruning (>= 1.0).
    pub alpha: f32,
    /// Number of PQ subspaces (M). `0` disables product quantization.
    pub pq_subspaces: usize,
    /// Number of training iterations passed to the product quantizer.
    pub pq_iterations: usize,
    /// Directory used for persistence. When set, a successful build also
    /// saves the index into this directory.
    pub storage_path: Option<PathBuf>,
}
40
41impl Default for DiskAnnConfig {
42    fn default() -> Self {
43        Self {
44            dim: 128,
45            max_degree: 64,
46            build_beam: 128,
47            search_beam: 64,
48            alpha: 1.2,
49            pq_subspaces: 0,
50            pq_iterations: 10,
51            storage_path: None,
52        }
53    }
54}
55
56/// DiskANN index with Vamana graph + optional PQ + mmap persistence
57pub struct DiskAnnIndex {
58    config: DiskAnnConfig,
59    /// Flat contiguous vector storage (cache-friendly)
60    vectors: FlatVectors,
61    /// ID mapping: internal index -> external string ID
62    id_map: Vec<String>,
63    /// Reverse mapping: external ID -> internal index
64    id_reverse: HashMap<String, u32>,
65    /// Vamana graph
66    graph: Option<VamanaGraph>,
67    /// Product quantizer (optional)
68    pq: Option<ProductQuantizer>,
69    /// PQ codes for all vectors
70    pq_codes: Vec<Vec<u8>>,
71    /// Whether index has been built
72    built: bool,
73    /// Reusable visited set for search (avoids per-query allocation)
74    visited: Option<VisitedSet>,
75    /// Memory-mapped vector data (for large datasets)
76    mmap: Option<Mmap>,
77}
78
79impl DiskAnnIndex {
80    /// Create a new DiskANN index
81    pub fn new(config: DiskAnnConfig) -> Self {
82        let dim = config.dim;
83        Self {
84            config,
85            vectors: FlatVectors::new(dim),
86            id_map: Vec::new(),
87            id_reverse: HashMap::new(),
88            graph: None,
89            pq: None,
90            pq_codes: Vec::new(),
91            built: false,
92            visited: None,
93            mmap: None,
94        }
95    }
96
97    /// Insert a vector with a string ID
98    pub fn insert(&mut self, id: String, vector: Vec<f32>) -> Result<()> {
99        if vector.len() != self.config.dim {
100            return Err(DiskAnnError::DimensionMismatch {
101                expected: self.config.dim,
102                actual: vector.len(),
103            });
104        }
105        if self.id_reverse.contains_key(&id) {
106            return Err(DiskAnnError::InvalidConfig(format!("Duplicate ID: {id}")));
107        }
108
109        let idx = self.vectors.len() as u32;
110        self.id_reverse.insert(id.clone(), idx);
111        self.id_map.push(id);
112        self.vectors.push(&vector);
113        self.built = false;
114        Ok(())
115    }
116
117    /// Insert a batch of vectors
118    pub fn insert_batch(&mut self, entries: Vec<(String, Vec<f32>)>) -> Result<()> {
119        for (id, vector) in entries {
120            self.insert(id, vector)?;
121        }
122        Ok(())
123    }
124
125    /// Build the index (must be called after all inserts, before search)
126    pub fn build(&mut self) -> Result<()> {
127        let n = self.vectors.len();
128        if n == 0 {
129            return Err(DiskAnnError::Empty);
130        }
131
132        // Train PQ if configured
133        if self.config.pq_subspaces > 0 {
134            // Collect vectors for PQ training
135            let vecs: Vec<Vec<f32>> = (0..n)
136                .map(|i| self.vectors.get(i).to_vec())
137                .collect();
138            let mut pq = ProductQuantizer::new(self.config.dim, self.config.pq_subspaces)?;
139            pq.train(&vecs, self.config.pq_iterations)?;
140
141            self.pq_codes = vecs
142                .iter()
143                .map(|v| pq.encode(v))
144                .collect::<Result<Vec<_>>>()?;
145
146            self.pq = Some(pq);
147        }
148
149        // Build Vamana graph on flat storage
150        let mut graph = VamanaGraph::new(
151            n,
152            self.config.max_degree,
153            self.config.build_beam,
154            self.config.alpha,
155        );
156        graph.build(&self.vectors)?;
157        self.graph = Some(graph);
158
159        // Pre-allocate visited set for search
160        self.visited = Some(VisitedSet::new(n));
161        self.built = true;
162
163        if let Some(ref path) = self.config.storage_path {
164            self.save(path)?;
165        }
166
167        Ok(())
168    }
169
170    /// Search for k nearest neighbors
171    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
172        if !self.built {
173            return Err(DiskAnnError::NotBuilt);
174        }
175        if query.len() != self.config.dim {
176            return Err(DiskAnnError::DimensionMismatch {
177                expected: self.config.dim,
178                actual: query.len(),
179            });
180        }
181
182        let graph = self.graph.as_ref().unwrap();
183        let beam = self.config.search_beam.max(k);
184
185        let (candidates, _) = graph.greedy_search(&self.vectors, query, beam);
186
187        // Re-rank candidates with exact distance
188        let mut scored: Vec<(u32, f32)> = candidates
189            .into_iter()
190            .map(|id| (id, l2_squared(self.vectors.get(id as usize), query)))
191            .collect();
192        scored.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
193
194        Ok(scored
195            .into_iter()
196            .take(k)
197            .map(|(id, dist)| SearchResult {
198                id: self.id_map[id as usize].clone(),
199                distance: dist,
200            })
201            .collect())
202    }
203
204    /// Get the number of vectors in the index
205    pub fn count(&self) -> usize {
206        self.vectors.len()
207    }
208
209    /// Delete a vector by ID (marks as deleted, doesn't rebuild graph)
210    pub fn delete(&mut self, id: &str) -> Result<bool> {
211        if let Some(&idx) = self.id_reverse.get(id) {
212            self.vectors.zero_out(idx as usize);
213            self.id_reverse.remove(id);
214            Ok(true)
215        } else {
216            Ok(false)
217        }
218    }
219
220    /// Save index to disk
221    pub fn save(&self, dir: &Path) -> Result<()> {
222        fs::create_dir_all(dir)?;
223
224        // Save vectors as flat binary (already contiguous — mmap-friendly)
225        let vec_path = dir.join("vectors.bin");
226        let mut f = BufWriter::new(File::create(&vec_path)?);
227        let n = self.vectors.len() as u64;
228        let dim = self.config.dim as u64;
229        f.write_all(&n.to_le_bytes())?;
230        f.write_all(&dim.to_le_bytes())?;
231        // Write flat slab directly — zero copy
232        let byte_slice = unsafe {
233            std::slice::from_raw_parts(
234                self.vectors.data.as_ptr() as *const u8,
235                self.vectors.data.len() * 4,
236            )
237        };
238        f.write_all(byte_slice)?;
239        f.flush()?;
240
241        // Save graph adjacency
242        let graph_path = dir.join("graph.bin");
243        let mut f = BufWriter::new(File::create(&graph_path)?);
244        if let Some(ref graph) = self.graph {
245            f.write_all(&(graph.medoid as u64).to_le_bytes())?;
246            f.write_all(&(graph.neighbors.len() as u64).to_le_bytes())?;
247            for neighbors in &graph.neighbors {
248                f.write_all(&(neighbors.len() as u32).to_le_bytes())?;
249                for &n in neighbors {
250                    f.write_all(&n.to_le_bytes())?;
251                }
252            }
253        }
254        f.flush()?;
255
256        // Save ID map
257        let ids_path = dir.join("ids.json");
258        let ids_json = serde_json::to_string(&self.id_map)
259            .map_err(|e| DiskAnnError::Serialization(e.to_string()))?;
260        fs::write(&ids_path, ids_json)?;
261
262        // Save PQ if present
263        if let Some(ref pq) = self.pq {
264            let pq_path = dir.join("pq.bin");
265            let pq_bytes = bincode::encode_to_vec(pq, bincode::config::standard())
266                .map_err(|e| DiskAnnError::Serialization(e.to_string()))?;
267            fs::write(&pq_path, pq_bytes)?;
268
269            // Save PQ codes
270            let codes_path = dir.join("pq_codes.bin");
271            let mut f = BufWriter::new(File::create(&codes_path)?);
272            for codes in &self.pq_codes {
273                f.write_all(codes)?;
274            }
275            f.flush()?;
276        }
277
278        // Save config
279        let config_path = dir.join("config.json");
280        let config_json = serde_json::json!({
281            "dim": self.config.dim,
282            "max_degree": self.config.max_degree,
283            "build_beam": self.config.build_beam,
284            "search_beam": self.config.search_beam,
285            "alpha": self.config.alpha,
286            "pq_subspaces": self.config.pq_subspaces,
287            "count": self.vectors.len(),
288            "built": self.built,
289        });
290        fs::write(&config_path, serde_json::to_string_pretty(&config_json).unwrap())?;
291
292        Ok(())
293    }
294
295    /// Load index from disk with memory-mapped vectors
296    pub fn load(dir: &Path) -> Result<Self> {
297        // Load config
298        let config_json: serde_json::Value =
299            serde_json::from_str(&fs::read_to_string(dir.join("config.json"))?)
300                .map_err(|e| DiskAnnError::Serialization(e.to_string()))?;
301
302        let dim = config_json["dim"].as_u64().unwrap() as usize;
303        let max_degree = config_json["max_degree"].as_u64().unwrap() as usize;
304        let build_beam = config_json["build_beam"].as_u64().unwrap() as usize;
305        let search_beam = config_json["search_beam"].as_u64().unwrap() as usize;
306        let alpha = config_json["alpha"].as_f64().unwrap() as f32;
307        let pq_subspaces = config_json["pq_subspaces"].as_u64().unwrap_or(0) as usize;
308
309        let config = DiskAnnConfig {
310            dim,
311            max_degree,
312            build_beam,
313            search_beam,
314            alpha,
315            pq_subspaces,
316            storage_path: Some(dir.to_path_buf()),
317            ..Default::default()
318        };
319
320        // Load vectors via mmap
321        let vec_file = File::open(dir.join("vectors.bin"))?;
322        let mmap = unsafe { MmapOptions::new().map(&vec_file)? };
323
324        let n = u64::from_le_bytes(mmap[0..8].try_into().unwrap()) as usize;
325        let file_dim = u64::from_le_bytes(mmap[8..16].try_into().unwrap()) as usize;
326        assert_eq!(file_dim, dim);
327
328        // Load vectors directly into flat slab from mmap
329        let data_start = 16;
330        let total_floats = n * dim;
331        let mut flat_data = Vec::with_capacity(total_floats);
332        let byte_slice = &mmap[data_start..data_start + total_floats * 4];
333        // Safe: f32 from le bytes
334        for chunk in byte_slice.chunks_exact(4) {
335            flat_data.push(f32::from_le_bytes(chunk.try_into().unwrap()));
336        }
337        let vectors = FlatVectors {
338            data: flat_data,
339            dim,
340            count: n,
341        };
342
343        // Load IDs
344        let ids_json = fs::read_to_string(dir.join("ids.json"))?;
345        let id_map: Vec<String> = serde_json::from_str(&ids_json)
346            .map_err(|e| DiskAnnError::Serialization(e.to_string()))?;
347
348        let mut id_reverse = HashMap::new();
349        for (i, id) in id_map.iter().enumerate() {
350            id_reverse.insert(id.clone(), i as u32);
351        }
352
353        // Load graph
354        let graph_bytes = fs::read(dir.join("graph.bin"))?;
355        let medoid = u64::from_le_bytes(graph_bytes[0..8].try_into().unwrap()) as u32;
356        let graph_n = u64::from_le_bytes(graph_bytes[8..16].try_into().unwrap()) as usize;
357
358        let mut neighbors = Vec::with_capacity(graph_n);
359        let mut offset = 16;
360        for _ in 0..graph_n {
361            let deg = u32::from_le_bytes(graph_bytes[offset..offset + 4].try_into().unwrap()) as usize;
362            offset += 4;
363            let mut nbrs = Vec::with_capacity(deg);
364            for _ in 0..deg {
365                let nbr = u32::from_le_bytes(graph_bytes[offset..offset + 4].try_into().unwrap());
366                offset += 4;
367                nbrs.push(nbr);
368            }
369            neighbors.push(nbrs);
370        }
371
372        let graph = VamanaGraph {
373            neighbors,
374            medoid,
375            max_degree,
376            build_beam,
377            alpha,
378        };
379
380        // Load PQ if present
381        let pq_path = dir.join("pq.bin");
382        let (pq, pq_codes) = if pq_path.exists() {
383            let pq_bytes = fs::read(&pq_path)?;
384            let (pq, _): (ProductQuantizer, usize) =
385                bincode::decode_from_slice(&pq_bytes, bincode::config::standard())
386                    .map_err(|e| DiskAnnError::Serialization(e.to_string()))?;
387
388            let codes_bytes = fs::read(dir.join("pq_codes.bin"))?;
389            let m = pq.m;
390            let mut codes = Vec::with_capacity(n);
391            for i in 0..n {
392                codes.push(codes_bytes[i * m..(i + 1) * m].to_vec());
393            }
394            (Some(pq), codes)
395        } else {
396            (None, Vec::new())
397        };
398
399        Ok(Self {
400            config,
401            vectors,
402            id_map,
403            id_reverse,
404            graph: Some(graph),
405            pq,
406            pq_codes,
407            built: true,
408            visited: Some(VisitedSet::new(n)),
409            mmap: Some(mmap),
410        })
411    }
412}
413
414#[cfg(test)]
415mod tests {
416    use super::*;
417    use tempfile::tempdir;
418
419    fn random_vectors(n: usize, dim: usize) -> Vec<(String, Vec<f32>)> {
420        use rand::prelude::*;
421        let mut rng = rand::thread_rng();
422        (0..n)
423            .map(|i| {
424                let v: Vec<f32> = (0..dim).map(|_| rng.gen()).collect();
425                (format!("vec-{i}"), v)
426            })
427            .collect()
428    }
429
430    fn random_data(n: usize, dim: usize) -> Vec<(String, Vec<f32>)> {
431        random_vectors(n, dim)
432    }
433
434    #[test]
435    fn test_diskann_basic() {
436        let mut index = DiskAnnIndex::new(DiskAnnConfig {
437            dim: 32,
438            max_degree: 16,
439            build_beam: 32,
440            search_beam: 32,
441            alpha: 1.2,
442            ..Default::default()
443        });
444
445        let data = random_vectors(500, 32);
446        let query = data[42].1.clone();
447
448        index.insert_batch(data).unwrap();
449        index.build().unwrap();
450
451        let results = index.search(&query, 5).unwrap();
452        assert!(!results.is_empty());
453        assert_eq!(results[0].id, "vec-42"); // Should find itself
454        assert!(results[0].distance < 1e-6); // Exact match
455    }
456
457    #[test]
458    fn test_diskann_with_pq() {
459        let mut index = DiskAnnIndex::new(DiskAnnConfig {
460            dim: 32,
461            max_degree: 16,
462            build_beam: 32,
463            search_beam: 32,
464            alpha: 1.2,
465            pq_subspaces: 4,
466            pq_iterations: 5,
467            ..Default::default()
468        });
469
470        let data = random_vectors(200, 32);
471        let query = data[10].1.clone();
472
473        index.insert_batch(data).unwrap();
474        index.build().unwrap();
475
476        let results = index.search(&query, 5).unwrap();
477        assert_eq!(results[0].id, "vec-10");
478    }
479
480    #[test]
481    fn test_diskann_save_load() {
482        let dir = tempdir().unwrap();
483        let path = dir.path().join("diskann_test");
484
485        let data = random_vectors(100, 16);
486        let query = data[7].1.clone();
487
488        // Build and save
489        {
490            let mut index = DiskAnnIndex::new(DiskAnnConfig {
491                dim: 16,
492                max_degree: 8,
493                build_beam: 16,
494                search_beam: 16,
495                alpha: 1.2,
496                storage_path: Some(path.clone()),
497                ..Default::default()
498            });
499            index.insert_batch(data).unwrap();
500            index.build().unwrap();
501        }
502
503        // Load and search
504        let loaded = DiskAnnIndex::load(&path).unwrap();
505        let results = loaded.search(&query, 3).unwrap();
506        assert_eq!(results[0].id, "vec-7");
507    }
508
509    #[test]
510    fn test_recall_at_10() {
511        // Measure recall@10: what fraction of true top-10 neighbors does DiskANN find?
512        use rand::prelude::*;
513        let mut rng = rand::thread_rng();
514        let n = 2000;
515        let dim = 64;
516        let k = 10;
517
518        let data: Vec<(String, Vec<f32>)> = (0..n)
519            .map(|i| {
520                let v: Vec<f32> = (0..dim).map(|_| rng.gen()).collect();
521                (format!("v{i}"), v)
522            })
523            .collect();
524
525        let mut index = DiskAnnIndex::new(DiskAnnConfig {
526            dim,
527            max_degree: 32,
528            build_beam: 64,
529            search_beam: 64,
530            alpha: 1.2,
531            ..Default::default()
532        });
533        index.insert_batch(data.clone()).unwrap();
534        index.build().unwrap();
535
536        // Test 50 random queries
537        let num_queries = 50;
538        let mut total_recall = 0.0;
539
540        for _ in 0..num_queries {
541            let qi = rng.gen_range(0..n);
542            let query = &data[qi].1;
543
544            // Brute-force ground truth
545            let mut brute: Vec<(usize, f32)> = data
546                .iter()
547                .enumerate()
548                .map(|(i, (_, v))| (i, crate::distance::l2_squared(v, query)))
549                .collect();
550            brute.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
551            let gt: std::collections::HashSet<String> = brute[..k]
552                .iter()
553                .map(|(i, _)| data[*i].0.clone())
554                .collect();
555
556            // DiskANN search
557            let results = index.search(query, k).unwrap();
558            let found: std::collections::HashSet<String> =
559                results.iter().map(|r| r.id.clone()).collect();
560
561            let recall = gt.intersection(&found).count() as f64 / k as f64;
562            total_recall += recall;
563        }
564
565        let avg_recall = total_recall / num_queries as f64;
566        println!("Recall@{k} = {avg_recall:.3} (n={n}, dim={dim}, queries={num_queries})");
567        assert!(
568            avg_recall >= 0.85,
569            "Recall@{k} = {avg_recall:.3}, expected >= 0.85"
570        );
571    }
572
573    #[test]
574    fn test_dimension_mismatch() {
575        let mut index = DiskAnnIndex::new(DiskAnnConfig {
576            dim: 16,
577            ..Default::default()
578        });
579
580        // Wrong dimension on insert
581        let result = index.insert("bad".to_string(), vec![1.0; 32]);
582        assert!(result.is_err());
583
584        // Wrong dimension on search
585        index.insert("ok".to_string(), vec![1.0; 16]).unwrap();
586        index.build().unwrap();
587        let result = index.search(&[1.0; 32], 1);
588        assert!(result.is_err());
589    }
590
591    #[test]
592    fn test_duplicate_id_rejected() {
593        let mut index = DiskAnnIndex::new(DiskAnnConfig {
594            dim: 4,
595            ..Default::default()
596        });
597        index.insert("a".to_string(), vec![1.0; 4]).unwrap();
598        let result = index.insert("a".to_string(), vec![2.0; 4]);
599        assert!(result.is_err());
600    }
601
602    #[test]
603    fn test_search_before_build_fails() {
604        let mut index = DiskAnnIndex::new(DiskAnnConfig {
605            dim: 4,
606            ..Default::default()
607        });
608        index.insert("a".to_string(), vec![1.0; 4]).unwrap();
609        let result = index.search(&[1.0; 4], 1);
610        assert!(result.is_err());
611    }
612
613    #[test]
614    fn test_scale_5k() {
615        // 5000 vectors, 128-dim — should build in under 5 seconds
616        use std::time::Instant;
617        use rand::prelude::*;
618        let mut rng = rand::thread_rng();
619
620        let n = 5000;
621        let dim = 128;
622        let data: Vec<(String, Vec<f32>)> = (0..n)
623            .map(|i| {
624                let v: Vec<f32> = (0..dim).map(|_| rng.gen()).collect();
625                (format!("v{i}"), v)
626            })
627            .collect();
628
629        let mut index = DiskAnnIndex::new(DiskAnnConfig {
630            dim,
631            max_degree: 48,
632            build_beam: 96,
633            search_beam: 48,
634            alpha: 1.2,
635            ..Default::default()
636        });
637        index.insert_batch(data.clone()).unwrap();
638
639        let t0 = Instant::now();
640        index.build().unwrap();
641        let build_ms = t0.elapsed().as_millis();
642        println!("Build {n} vectors ({dim}d): {build_ms}ms");
643
644        // Search latency
645        let query = &data[0].1;
646        let t0 = Instant::now();
647        let iters = 100;
648        for _ in 0..iters {
649            let _ = index.search(query, 10).unwrap();
650        }
651        let search_us = t0.elapsed().as_micros() / iters;
652        println!("Search latency (k=10): {search_us}µs avg over {iters} queries");
653
654        assert!(search_us < 10_000, "Search took {search_us}µs, expected <10ms");
655    }
656}