Skip to main content

superbit/
index.rs

1use std::sync::Arc;
2
3use hashbrown::HashMap;
4use ndarray::Array1;
5use parking_lot::RwLock;
6use rand::rngs::StdRng;
7use rand::SeedableRng;
8
9use crate::distance::{self, DistanceMetric};
10use crate::error::{LshError, Result};
11use crate::hash::{multi_probe_keys, RandomProjectionHasher};
12use crate::metrics::{MetricsCollector, MetricsSnapshot, QueryTimer};
13
14/// Configuration for the LSH index.
15#[derive(Debug, Clone)]
16#[cfg_attr(
17    feature = "persistence",
18    derive(serde::Serialize, serde::Deserialize)
19)]
20pub struct IndexConfig {
21    /// Dimensionality of vectors.
22    pub dim: usize,
23    /// Number of hash bits per table (1..=64).
24    pub num_hashes: usize,
25    /// Number of independent hash tables.
26    pub num_tables: usize,
27    /// Number of extra buckets to probe per table during queries.
28    pub num_probes: usize,
29    /// Distance metric for ranking candidates.
30    pub distance_metric: DistanceMetric,
31    /// Whether to L2-normalize vectors on insertion (recommended for cosine).
32    pub normalize_vectors: bool,
33    /// Optional RNG seed for reproducible projections.
34    pub seed: Option<u64>,
35}
36
37impl Default for IndexConfig {
38    fn default() -> Self {
39        Self {
40            dim: 768,
41            num_hashes: 8,
42            num_tables: 16,
43            num_probes: 3,
44            distance_metric: DistanceMetric::Cosine,
45            normalize_vectors: true,
46            seed: None,
47        }
48    }
49}
50
/// A single nearest-neighbor result returned by a query.
#[derive(Debug, Clone, PartialEq)]
pub struct QueryResult {
    /// ID of the matching stored vector.
    pub id: usize,
    /// Distance from the query vector (lower is closer).
    pub distance: f32,
}
59
/// Aggregate statistics about the index, as produced by `LshIndex::stats`.
#[derive(Debug, Clone, PartialEq)]
pub struct IndexStats {
    /// Number of stored vectors.
    pub num_vectors: usize,
    /// Number of hash tables.
    pub num_tables: usize,
    /// Hash bits per table.
    pub num_hashes: usize,
    /// Vector dimensionality.
    pub dimension: usize,
    /// Total number of buckets across all tables.
    pub total_buckets: usize,
    /// Mean number of IDs per bucket (`0.0` when there are no buckets).
    pub avg_bucket_size: f64,
    /// Size of the largest bucket in any table (`0` when there are no buckets).
    pub max_bucket_size: usize,
    /// Rough memory footprint in bytes; an estimate, not a measurement.
    pub memory_estimate_bytes: usize,
}

impl std::fmt::Display for IndexStats {
    /// Renders a one-line summary; memory is shown in MiB with one decimal
    /// place (labelled `MB`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const BYTES_PER_MIB: f64 = 1024.0 * 1024.0;
        write!(
            f,
            "LshIndex {{ vectors: {}, tables: {}, hashes/table: {}, dim: {}, \
             buckets: {}, avg_bucket: {:.1}, max_bucket: {}, mem: ~{:.1}MB }}",
            self.num_vectors,
            self.num_tables,
            self.num_hashes,
            self.dimension,
            self.total_buckets,
            self.avg_bucket_size,
            self.max_bucket_size,
            self.memory_estimate_bytes as f64 / BYTES_PER_MIB,
        )
    }
}
90
91// ---------------------------------------------------------------------------
92// Inner state (behind RwLock)
93// ---------------------------------------------------------------------------
94
95#[cfg_attr(
96    feature = "persistence",
97    derive(serde::Serialize, serde::Deserialize)
98)]
99pub(crate) struct IndexInner {
100    pub(crate) vectors: HashMap<usize, Array1<f32>>,
101    pub(crate) tables: Vec<HashMap<u64, Vec<usize>>>,
102    pub(crate) hashers: Vec<RandomProjectionHasher>,
103    pub(crate) config: IndexConfig,
104    pub(crate) next_id: usize,
105}
106
107// ---------------------------------------------------------------------------
108// LshIndex
109// ---------------------------------------------------------------------------
110
111/// A locality-sensitive hashing index for approximate nearest-neighbor search.
112///
113/// Thread-safe: concurrent reads (queries) proceed in parallel; writes
114/// (inserts, removes) acquire exclusive access via `parking_lot::RwLock`.
115pub struct LshIndex {
116    pub(crate) inner: RwLock<IndexInner>,
117    pub(crate) metrics: Option<Arc<MetricsCollector>>,
118}
119
120impl std::fmt::Debug for LshIndex {
121    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122        let inner = self.inner.read();
123        f.debug_struct("LshIndex")
124            .field("num_vectors", &inner.vectors.len())
125            .field("config", &inner.config)
126            .field("has_metrics", &self.metrics.is_some())
127            .finish()
128    }
129}
130
131impl LshIndex {
132    /// Start building an index with the builder pattern.
133    pub fn builder() -> LshIndexBuilder {
134        LshIndexBuilder::new()
135    }
136
137    /// Create an index directly from an [`IndexConfig`].
138    pub fn new(config: IndexConfig) -> Result<Self> {
139        Self::new_with_metrics(config, false)
140    }
141
142    fn new_with_metrics(config: IndexConfig, enable_metrics: bool) -> Result<Self> {
143        if config.dim == 0 {
144            return Err(LshError::ZeroDimension);
145        }
146        if config.num_hashes == 0 || config.num_hashes > 64 {
147            return Err(LshError::InvalidNumHashes(config.num_hashes));
148        }
149        if config.num_tables == 0 {
150            return Err(LshError::InvalidConfig(
151                "num_tables must be > 0".into(),
152            ));
153        }
154
155        let mut rng = match config.seed {
156            Some(seed) => StdRng::seed_from_u64(seed),
157            None => StdRng::from_entropy(),
158        };
159
160        let hashers: Vec<RandomProjectionHasher> = (0..config.num_tables)
161            .map(|_| RandomProjectionHasher::new(config.dim, config.num_hashes, &mut rng))
162            .collect();
163
164        let tables = (0..config.num_tables).map(|_| HashMap::new()).collect();
165
166        let inner = IndexInner {
167            vectors: HashMap::new(),
168            tables,
169            hashers,
170            config,
171            next_id: 0,
172        };
173
174        let metrics = if enable_metrics {
175            Some(Arc::new(MetricsCollector::new()))
176        } else {
177            None
178        };
179
180        Ok(Self {
181            inner: RwLock::new(inner),
182            metrics,
183        })
184    }
185
186    // ------------------------------------------------------------------
187    // Insertion
188    // ------------------------------------------------------------------
189
190    /// Insert a vector with the given ID.
191    ///
192    /// If a vector with this ID already exists it is silently replaced.
193    pub fn insert(&self, id: usize, vector: &[f32]) -> Result<()> {
194        let mut inner = self.inner.write();
195
196        if vector.len() != inner.config.dim {
197            return Err(LshError::DimensionMismatch {
198                expected: inner.config.dim,
199                got: vector.len(),
200            });
201        }
202
203        // If the id already exists, remove old hashes first.
204        if let Some(old_vec) = inner.vectors.get(&id) {
205            let old_vec = old_vec.clone();
206            let old_hashes: Vec<u64> = inner
207                .hashers
208                .iter()
209                .map(|h| h.hash_vector_fast(&old_vec.view()))
210                .collect();
211            for (i, old_hash) in old_hashes.into_iter().enumerate() {
212                if let Some(bucket) = inner.tables[i].get_mut(&old_hash) {
213                    bucket.retain(|&x| x != id);
214                    if bucket.is_empty() {
215                        inner.tables[i].remove(&old_hash);
216                    }
217                }
218            }
219        }
220
221        let mut arr = Array1::from_vec(vector.to_vec());
222        if inner.config.normalize_vectors {
223            distance::normalize(&mut arr);
224        }
225
226        let new_hashes: Vec<u64> = inner
227            .hashers
228            .iter()
229            .map(|h| h.hash_vector_fast(&arr.view()))
230            .collect();
231        for (i, hash) in new_hashes.into_iter().enumerate() {
232            inner.tables[i].entry(hash).or_default().push(id);
233        }
234
235        inner.vectors.insert(id, arr);
236
237        if id >= inner.next_id {
238            inner.next_id = id + 1;
239        }
240
241        if let Some(ref m) = self.metrics {
242            m.record_insert();
243        }
244
245        Ok(())
246    }
247
248    /// Insert a vector and receive an auto-assigned ID.
249    ///
250    /// The ID is assigned atomically under the write lock, so concurrent
251    /// calls will never produce duplicate IDs.
252    pub fn insert_auto(&self, vector: &[f32]) -> Result<usize> {
253        let mut inner = self.inner.write();
254
255        if vector.len() != inner.config.dim {
256            return Err(LshError::DimensionMismatch {
257                expected: inner.config.dim,
258                got: vector.len(),
259            });
260        }
261
262        let id = inner.next_id;
263
264        let mut arr = Array1::from_vec(vector.to_vec());
265        if inner.config.normalize_vectors {
266            distance::normalize(&mut arr);
267        }
268
269        let new_hashes: Vec<u64> = inner
270            .hashers
271            .iter()
272            .map(|h| h.hash_vector_fast(&arr.view()))
273            .collect();
274        for (i, hash) in new_hashes.into_iter().enumerate() {
275            inner.tables[i].entry(hash).or_default().push(id);
276        }
277
278        inner.vectors.insert(id, arr);
279        inner.next_id = id + 1;
280
281        if let Some(ref m) = self.metrics {
282            m.record_insert();
283        }
284
285        Ok(id)
286    }
287
288    /// Insert multiple vectors at once. Aborts on first error.
289    pub fn insert_batch(&self, vectors: &[(usize, &[f32])]) -> Result<()> {
290        for &(id, v) in vectors {
291            self.insert(id, v)?;
292        }
293        Ok(())
294    }
295
296    // ------------------------------------------------------------------
297    // Query
298    // ------------------------------------------------------------------
299
300    /// Find the `k` approximate nearest neighbors for `vector`.
301    ///
302    /// Returns results sorted by ascending distance (closest first).
303    pub fn query(&self, vector: &[f32], k: usize) -> Result<Vec<QueryResult>> {
304        let timer = self.metrics.as_ref().map(|_| QueryTimer::new());
305        let inner = self.inner.read();
306
307        if vector.len() != inner.config.dim {
308            return Err(LshError::DimensionMismatch {
309                expected: inner.config.dim,
310                got: vector.len(),
311            });
312        }
313
314        if inner.vectors.is_empty() {
315            return Ok(Vec::new());
316        }
317
318        let mut query_vec = Array1::from_vec(vector.to_vec());
319        if inner.config.normalize_vectors {
320            distance::normalize(&mut query_vec);
321        }
322
323        // Collect candidate IDs across all tables.
324        // Use a bitvec for O(1) dedup when IDs are dense sequential integers,
325        // falling back to HashMap when IDs are sparse (next_id > 4 * num_vectors).
326        let num_vectors = inner.vectors.len();
327        let use_bitvec = inner.next_id <= num_vectors.saturating_mul(4);
328        let mut seen = if use_bitvec {
329            vec![false; inner.next_id]
330        } else {
331            Vec::new()
332        };
333        let mut candidate_set: HashMap<usize, ()> = if use_bitvec {
334            HashMap::new() // unused, but need a binding
335        } else {
336            HashMap::with_capacity(num_vectors / 4)
337        };
338        let mut candidate_ids: Vec<usize> = Vec::new();
339
340        for (i, hasher) in inner.hashers.iter().enumerate() {
341            let (hash, margins) = hasher.hash_vector(&query_vec.view());
342
343            let probe_keys = if inner.config.num_probes > 0 {
344                multi_probe_keys(hash, &margins, inner.config.num_probes)
345            } else {
346                vec![hash]
347            };
348
349            for key in probe_keys {
350                if let Some(bucket) = inner.tables[i].get(&key) {
351                    if let Some(ref m) = self.metrics {
352                        m.record_bucket_hit();
353                    }
354                    for &id in bucket {
355                        if use_bitvec {
356                            if !seen[id] {
357                                seen[id] = true;
358                                candidate_ids.push(id);
359                            }
360                        } else if candidate_set.insert(id, ()).is_none() {
361                            candidate_ids.push(id);
362                        }
363                    }
364                } else if let Some(ref m) = self.metrics {
365                    m.record_bucket_miss();
366                }
367            }
368        }
369
370        // Exact re-ranking of candidates.
371        // When vectors are pre-normalized (cosine mode), use the fast 1-dot path
372        // which avoids two redundant norm computations per candidate.
373        let use_fast_cosine = inner.config.normalize_vectors
374            && inner.config.distance_metric == distance::DistanceMetric::Cosine;
375        let query_view = query_vec.view();
376
377        let num_candidates = candidate_ids.len();
378
379        let mut results: Vec<QueryResult> = candidate_ids
380            .iter()
381            .filter_map(|&id| {
382                inner.vectors.get(&id).map(|stored| {
383                    let dist = if use_fast_cosine {
384                        distance::cosine_distance_normalized(&query_view, &stored.view())
385                    } else {
386                        inner
387                            .config
388                            .distance_metric
389                            .compute(&query_view, &stored.view())
390                    };
391                    QueryResult { id, distance: dist }
392                })
393            })
394            .collect();
395
396        results.sort_by(|a, b| {
397            a.distance
398                .partial_cmp(&b.distance)
399                .unwrap_or(std::cmp::Ordering::Equal)
400        });
401        results.truncate(k);
402
403        if let Some(ref m) = self.metrics {
404            if let Some(t) = timer {
405                m.record_query(num_candidates as u64, t.elapsed_ns());
406            }
407        }
408
409        Ok(results)
410    }
411
412    // ------------------------------------------------------------------
413    // Removal / lookup
414    // ------------------------------------------------------------------
415
416    /// Remove a vector by ID.
417    pub fn remove(&self, id: usize) -> Result<()> {
418        let mut inner = self.inner.write();
419
420        let vec = inner.vectors.remove(&id).ok_or(LshError::NotFound(id))?;
421
422        let hashes: Vec<u64> = inner
423            .hashers
424            .iter()
425            .map(|h| h.hash_vector_fast(&vec.view()))
426            .collect();
427        for (i, hash) in hashes.into_iter().enumerate() {
428            if let Some(bucket) = inner.tables[i].get_mut(&hash) {
429                bucket.retain(|&x| x != id);
430                if bucket.is_empty() {
431                    inner.tables[i].remove(&hash);
432                }
433            }
434        }
435
436        Ok(())
437    }
438
439    /// Check whether a vector ID is present.
440    pub fn contains(&self, id: usize) -> bool {
441        self.inner.read().vectors.contains_key(&id)
442    }
443
444    // ------------------------------------------------------------------
445    // Stats / metrics
446    // ------------------------------------------------------------------
447
448    /// Number of stored vectors.
449    pub fn len(&self) -> usize {
450        self.inner.read().vectors.len()
451    }
452
453    /// True when the index holds no vectors.
454    pub fn is_empty(&self) -> bool {
455        self.inner.read().vectors.is_empty()
456    }
457
458    /// Compute aggregate statistics about the index.
459    pub fn stats(&self) -> IndexStats {
460        let inner = self.inner.read();
461
462        let total_buckets: usize = inner.tables.iter().map(|t| t.len()).sum();
463        let total_entries: usize = inner
464            .tables
465            .iter()
466            .flat_map(|t| t.values())
467            .map(|v| v.len())
468            .sum();
469        let max_bucket_size = inner
470            .tables
471            .iter()
472            .flat_map(|t| t.values())
473            .map(|v| v.len())
474            .max()
475            .unwrap_or(0);
476
477        let avg_bucket_size = if total_buckets > 0 {
478            total_entries as f64 / total_buckets as f64
479        } else {
480            0.0
481        };
482
483        let vector_mem =
484            inner.vectors.len() * (inner.config.dim * 4 + std::mem::size_of::<usize>());
485        let table_mem = total_buckets * (std::mem::size_of::<u64>() + 24);
486        let entry_mem = total_entries * std::mem::size_of::<usize>();
487        let proj_mem =
488            inner.config.num_tables * inner.config.num_hashes * inner.config.dim * 4;
489
490        IndexStats {
491            num_vectors: inner.vectors.len(),
492            num_tables: inner.config.num_tables,
493            num_hashes: inner.config.num_hashes,
494            dimension: inner.config.dim,
495            total_buckets,
496            avg_bucket_size,
497            max_bucket_size,
498            memory_estimate_bytes: vector_mem + table_mem + entry_mem + proj_mem,
499        }
500    }
501
502    /// Snapshot of runtime metrics (`None` if metrics were not enabled).
503    pub fn metrics(&self) -> Option<MetricsSnapshot> {
504        self.metrics.as_ref().map(|m| m.snapshot())
505    }
506
507    /// Reset metrics counters.
508    pub fn reset_metrics(&self) {
509        if let Some(ref m) = self.metrics {
510            m.reset();
511        }
512    }
513
514    /// Remove all vectors from the index (projections are preserved).
515    pub fn clear(&self) {
516        let mut inner = self.inner.write();
517        inner.vectors.clear();
518        for table in &mut inner.tables {
519            table.clear();
520        }
521        inner.next_id = 0;
522    }
523
524    /// Return a clone of the current configuration.
525    pub fn config(&self) -> IndexConfig {
526        self.inner.read().config.clone()
527    }
528}
529
530// ---------------------------------------------------------------------------
531// Parallel batch ops (behind `parallel` feature)
532// ---------------------------------------------------------------------------
533
#[cfg(feature = "parallel")]
impl LshIndex {
    /// Insert many vectors, normalising and hashing them in parallel with rayon.
    ///
    /// All dimensions are validated up front, so a mismatch leaves the index
    /// unchanged. Hashing runs without holding the lock, against a snapshot of
    /// the configuration and hashers. The results are then written under a
    /// single write lock. An ID that already exists, or repeats within the
    /// batch, is replaced as with `insert`, and each entry is counted in the
    /// insert metrics.
    ///
    /// # Errors
    ///
    /// [`LshError::DimensionMismatch`] for the first vector whose length
    /// differs from `dim`.
    pub fn par_insert_batch(&self, vectors: &[(usize, Vec<f32>)]) -> Result<()> {
        use rayon::prelude::*;

        // Snapshot config + hashers so we don't hold the lock during parallel work.
        let (config, hashers) = {
            let inner = self.inner.read();
            (inner.config.clone(), inner.hashers.clone())
        };

        if let Some((_, v)) = vectors.iter().find(|(_, v)| v.len() != config.dim) {
            return Err(LshError::DimensionMismatch {
                expected: config.dim,
                got: v.len(),
            });
        }

        // Parallel: normalise + compute one bucket key per table.
        let prepared: Vec<(usize, Array1<f32>, Vec<u64>)> = vectors
            .par_iter()
            .map(|(id, v)| {
                let mut arr = Array1::from_vec(v.clone());
                if config.normalize_vectors {
                    distance::normalize(&mut arr);
                }
                let keys: Vec<u64> = hashers
                    .iter()
                    .map(|h| h.hash_vector_fast(&arr.view()))
                    .collect();
                (*id, arr, keys)
            })
            .collect();

        // Sequential: write to shared state.
        let mut guard = self.inner.write();
        // Reborrow so disjoint fields can be borrowed independently.
        let inner = &mut *guard;
        for (id, arr, keys) in prepared {
            // Replacing an existing ID: unlink its old bucket entries first.
            if let Some(old) = inner.vectors.remove(&id) {
                for (hasher, table) in hashers.iter().zip(inner.tables.iter_mut()) {
                    let old_key = hasher.hash_vector_fast(&old.view());
                    if let Some(bucket) = table.get_mut(&old_key) {
                        bucket.retain(|&x| x != id);
                        if bucket.is_empty() {
                            table.remove(&old_key);
                        }
                    }
                }
            }

            for (table, key) in inner.tables.iter_mut().zip(keys) {
                table.entry(key).or_default().push(id);
            }
            inner.vectors.insert(id, arr);
            // Saturate rather than overflow for `id == usize::MAX`.
            inner.next_id = inner.next_id.max(id.saturating_add(1));

            if let Some(m) = &self.metrics {
                m.record_insert();
            }
        }

        Ok(())
    }

    /// Query multiple vectors in parallel.
    ///
    /// Each query takes its own read lock. Results are returned in the same
    /// order as `queries`. If any query fails, an error is returned instead.
    pub fn par_query_batch(
        &self,
        queries: &[Vec<f32>],
        k: usize,
    ) -> Result<Vec<Vec<QueryResult>>> {
        use rayon::prelude::*;

        queries
            .par_iter()
            .map(|q| self.query(q, k))
            .collect()
    }
}
618
619// ---------------------------------------------------------------------------
620// Builder
621// ---------------------------------------------------------------------------
622
623/// Fluent builder for [`LshIndex`].
624#[derive(Default)]
625pub struct LshIndexBuilder {
626    config: IndexConfig,
627    enable_metrics: bool,
628}
629
630impl LshIndexBuilder {
631    pub fn new() -> Self {
632        Self::default()
633    }
634
635    pub fn dim(mut self, dim: usize) -> Self {
636        self.config.dim = dim;
637        self
638    }
639
640    pub fn num_hashes(mut self, n: usize) -> Self {
641        self.config.num_hashes = n;
642        self
643    }
644
645    pub fn num_tables(mut self, n: usize) -> Self {
646        self.config.num_tables = n;
647        self
648    }
649
650    pub fn num_probes(mut self, n: usize) -> Self {
651        self.config.num_probes = n;
652        self
653    }
654
655    pub fn distance_metric(mut self, m: DistanceMetric) -> Self {
656        self.config.distance_metric = m;
657        self
658    }
659
660    pub fn normalize(mut self, yes: bool) -> Self {
661        self.config.normalize_vectors = yes;
662        self
663    }
664
665    pub fn seed(mut self, seed: u64) -> Self {
666        self.config.seed = Some(seed);
667        self
668    }
669
670    pub fn enable_metrics(mut self) -> Self {
671        self.enable_metrics = true;
672        self
673    }
674
675    /// Build the index, returning an error on invalid configuration.
676    pub fn build(self) -> Result<LshIndex> {
677        LshIndex::new_with_metrics(self.config, self.enable_metrics)
678    }
679}