// superbit/index.rs

1use std::sync::Arc;
2
3use hashbrown::HashMap;
4use ndarray::Array1;
5use parking_lot::RwLock;
6use rand::rngs::StdRng;
7use rand::SeedableRng;
8
9use crate::distance::{self, DistanceMetric};
10use crate::error::{LshError, Result};
11use crate::hash::{multi_probe_keys, RandomProjectionHasher};
12use crate::metrics::{MetricsCollector, MetricsSnapshot, QueryTimer};
13
14/// Configuration for the LSH index.
15#[derive(Debug, Clone)]
16#[cfg_attr(
17    feature = "persistence",
18    derive(serde::Serialize, serde::Deserialize)
19)]
20pub struct IndexConfig {
21    /// Dimensionality of vectors.
22    pub dim: usize,
23    /// Number of hash bits per table (1..=64).
24    pub num_hashes: usize,
25    /// Number of independent hash tables.
26    pub num_tables: usize,
27    /// Number of extra buckets to probe per table during queries.
28    pub num_probes: usize,
29    /// Distance metric for ranking candidates.
30    pub distance_metric: DistanceMetric,
31    /// Whether to L2-normalize vectors on insertion (recommended for cosine).
32    pub normalize_vectors: bool,
33    /// Optional RNG seed for reproducible projections.
34    pub seed: Option<u64>,
35}
36
37impl Default for IndexConfig {
38    fn default() -> Self {
39        Self {
40            dim: 768,
41            num_hashes: 8,
42            num_tables: 16,
43            num_probes: 3,
44            distance_metric: DistanceMetric::Cosine,
45            normalize_vectors: true,
46            seed: None,
47        }
48    }
49}
50
/// A single approximate nearest-neighbor result.
///
/// Two scalar fields, so the type is `Copy` and can be compared directly.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct QueryResult {
    /// ID the vector was inserted under.
    pub id: usize,
    /// Distance from the query under the index's configured metric (lower is closer).
    pub distance: f32,
}
59
/// Aggregate statistics about the index.
#[derive(Debug, Clone, PartialEq)]
pub struct IndexStats {
    /// Number of stored vectors.
    pub num_vectors: usize,
    /// Number of hash tables.
    pub num_tables: usize,
    /// Hash bits per table.
    pub num_hashes: usize,
    /// Vector dimensionality.
    pub dimension: usize,
    /// Number of buckets summed over all tables.
    pub total_buckets: usize,
    /// Mean number of IDs per bucket (`0.0` when there are no buckets).
    pub avg_bucket_size: f64,
    /// Largest number of IDs found in any single bucket.
    pub max_bucket_size: usize,
    /// Rough memory-usage estimate in bytes.
    pub memory_estimate_bytes: usize,
}

impl std::fmt::Display for IndexStats {
    /// One-line summary; the memory estimate is converted to MiB (labelled "MB").
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mem_mb = self.memory_estimate_bytes as f64 / (1024.0 * 1024.0);
        write!(
            f,
            "LshIndex {{ vectors: {}, tables: {}, hashes/table: {}, dim: {}, \
             buckets: {}, avg_bucket: {:.1}, max_bucket: {}, mem: ~{:.1}MB }}",
            self.num_vectors,
            self.num_tables,
            self.num_hashes,
            self.dimension,
            self.total_buckets,
            self.avg_bucket_size,
            self.max_bucket_size,
            mem_mb,
        )
    }
}
90
91// ---------------------------------------------------------------------------
92// Inner state (behind RwLock)
93// ---------------------------------------------------------------------------
94
95#[cfg_attr(
96    feature = "persistence",
97    derive(serde::Serialize, serde::Deserialize)
98)]
99pub(crate) struct IndexInner {
100    pub(crate) vectors: HashMap<usize, Array1<f32>>,
101    pub(crate) tables: Vec<HashMap<u64, Vec<usize>>>,
102    pub(crate) hashers: Vec<RandomProjectionHasher>,
103    pub(crate) config: IndexConfig,
104    pub(crate) next_id: usize,
105}
106
107// ---------------------------------------------------------------------------
108// LshIndex
109// ---------------------------------------------------------------------------
110
111/// A locality-sensitive hashing index for approximate nearest-neighbor search.
112///
113/// Thread-safe: concurrent reads (queries) proceed in parallel; writes
114/// (inserts, removes) acquire exclusive access via `parking_lot::RwLock`.
115pub struct LshIndex {
116    pub(crate) inner: RwLock<IndexInner>,
117    pub(crate) metrics: Option<Arc<MetricsCollector>>,
118}
119
120impl std::fmt::Debug for LshIndex {
121    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122        let inner = self.inner.read();
123        f.debug_struct("LshIndex")
124            .field("num_vectors", &inner.vectors.len())
125            .field("config", &inner.config)
126            .field("has_metrics", &self.metrics.is_some())
127            .finish()
128    }
129}
130
131impl LshIndex {
132    /// Start building an index with the builder pattern.
133    pub fn builder() -> LshIndexBuilder {
134        LshIndexBuilder::new()
135    }
136
137    /// Create an index directly from an [`IndexConfig`].
138    pub fn new(config: IndexConfig) -> Result<Self> {
139        Self::new_with_metrics(config, false)
140    }
141
142    fn new_with_metrics(config: IndexConfig, enable_metrics: bool) -> Result<Self> {
143        if config.dim == 0 {
144            return Err(LshError::ZeroDimension);
145        }
146        if config.num_hashes == 0 || config.num_hashes > 64 {
147            return Err(LshError::InvalidNumHashes(config.num_hashes));
148        }
149        if config.num_tables == 0 {
150            return Err(LshError::InvalidConfig(
151                "num_tables must be > 0".into(),
152            ));
153        }
154
155        let mut rng = match config.seed {
156            Some(seed) => StdRng::seed_from_u64(seed),
157            None => StdRng::from_entropy(),
158        };
159
160        let hashers: Vec<RandomProjectionHasher> = (0..config.num_tables)
161            .map(|_| RandomProjectionHasher::new(config.dim, config.num_hashes, &mut rng))
162            .collect();
163
164        let tables = (0..config.num_tables).map(|_| HashMap::new()).collect();
165
166        let inner = IndexInner {
167            vectors: HashMap::new(),
168            tables,
169            hashers,
170            config,
171            next_id: 0,
172        };
173
174        let metrics = if enable_metrics {
175            Some(Arc::new(MetricsCollector::new()))
176        } else {
177            None
178        };
179
180        Ok(Self {
181            inner: RwLock::new(inner),
182            metrics,
183        })
184    }
185
186    // ------------------------------------------------------------------
187    // Insertion
188    // ------------------------------------------------------------------
189
190    /// Insert a vector with the given ID.
191    ///
192    /// If a vector with this ID already exists it is silently replaced.
193    pub fn insert(&self, id: usize, vector: &[f32]) -> Result<()> {
194        let mut inner = self.inner.write();
195
196        if vector.len() != inner.config.dim {
197            return Err(LshError::DimensionMismatch {
198                expected: inner.config.dim,
199                got: vector.len(),
200            });
201        }
202
203        // If the id already exists, remove old hashes first.
204        if let Some(old_vec) = inner.vectors.get(&id) {
205            let old_vec = old_vec.clone();
206            let old_hashes: Vec<u64> = inner
207                .hashers
208                .iter()
209                .map(|h| h.hash_vector_fast(&old_vec.view()))
210                .collect();
211            for (i, old_hash) in old_hashes.into_iter().enumerate() {
212                if let Some(bucket) = inner.tables[i].get_mut(&old_hash) {
213                    bucket.retain(|&x| x != id);
214                    if bucket.is_empty() {
215                        inner.tables[i].remove(&old_hash);
216                    }
217                }
218            }
219        }
220
221        let mut arr = Array1::from_vec(vector.to_vec());
222        if inner.config.normalize_vectors {
223            distance::normalize(&mut arr);
224        }
225
226        let new_hashes: Vec<u64> = inner
227            .hashers
228            .iter()
229            .map(|h| h.hash_vector_fast(&arr.view()))
230            .collect();
231        for (i, hash) in new_hashes.into_iter().enumerate() {
232            inner.tables[i].entry(hash).or_default().push(id);
233        }
234
235        inner.vectors.insert(id, arr);
236
237        if id >= inner.next_id {
238            inner.next_id = id + 1;
239        }
240
241        if let Some(ref m) = self.metrics {
242            m.record_insert();
243        }
244
245        Ok(())
246    }
247
248    /// Insert a vector and receive an auto-assigned ID.
249    ///
250    /// The ID is assigned atomically under the write lock, so concurrent
251    /// calls will never produce duplicate IDs.
252    pub fn insert_auto(&self, vector: &[f32]) -> Result<usize> {
253        let mut inner = self.inner.write();
254
255        if vector.len() != inner.config.dim {
256            return Err(LshError::DimensionMismatch {
257                expected: inner.config.dim,
258                got: vector.len(),
259            });
260        }
261
262        let id = inner.next_id;
263
264        let mut arr = Array1::from_vec(vector.to_vec());
265        if inner.config.normalize_vectors {
266            distance::normalize(&mut arr);
267        }
268
269        let new_hashes: Vec<u64> = inner
270            .hashers
271            .iter()
272            .map(|h| h.hash_vector_fast(&arr.view()))
273            .collect();
274        for (i, hash) in new_hashes.into_iter().enumerate() {
275            inner.tables[i].entry(hash).or_default().push(id);
276        }
277
278        inner.vectors.insert(id, arr);
279        inner.next_id = id + 1;
280
281        if let Some(ref m) = self.metrics {
282            m.record_insert();
283        }
284
285        Ok(id)
286    }
287
288    /// Insert multiple vectors at once. Aborts on first error.
289    pub fn insert_batch(&self, vectors: &[(usize, &[f32])]) -> Result<()> {
290        for &(id, v) in vectors {
291            self.insert(id, v)?;
292        }
293        Ok(())
294    }
295
296    // ------------------------------------------------------------------
297    // Query
298    // ------------------------------------------------------------------
299
300    /// Find the `k` approximate nearest neighbors for `vector`.
301    ///
302    /// Returns results sorted by ascending distance (closest first).
303    pub fn query(&self, vector: &[f32], k: usize) -> Result<Vec<QueryResult>> {
304        let timer = self.metrics.as_ref().map(|_| QueryTimer::new());
305        let inner = self.inner.read();
306
307        if vector.len() != inner.config.dim {
308            return Err(LshError::DimensionMismatch {
309                expected: inner.config.dim,
310                got: vector.len(),
311            });
312        }
313
314        if inner.vectors.is_empty() {
315            return Ok(Vec::new());
316        }
317
318        let mut query_vec = Array1::from_vec(vector.to_vec());
319        if inner.config.normalize_vectors {
320            distance::normalize(&mut query_vec);
321        }
322
323        // Collect candidate IDs across all tables.
324        let mut candidates: HashMap<usize, ()> = HashMap::new();
325
326        for (i, hasher) in inner.hashers.iter().enumerate() {
327            let (hash, margins) = hasher.hash_vector(&query_vec.view());
328
329            let probe_keys = if inner.config.num_probes > 0 {
330                multi_probe_keys(hash, &margins, inner.config.num_probes)
331            } else {
332                vec![hash]
333            };
334
335            for key in probe_keys {
336                if let Some(bucket) = inner.tables[i].get(&key) {
337                    if let Some(ref m) = self.metrics {
338                        m.record_bucket_hit();
339                    }
340                    for &id in bucket {
341                        candidates.insert(id, ());
342                    }
343                } else if let Some(ref m) = self.metrics {
344                    m.record_bucket_miss();
345                }
346            }
347        }
348
349        // Exact re-ranking of candidates.
350        let mut results: Vec<QueryResult> = candidates
351            .keys()
352            .filter_map(|&id| {
353                inner.vectors.get(&id).map(|stored| {
354                    let dist = inner
355                        .config
356                        .distance_metric
357                        .compute(&query_vec.view(), &stored.view());
358                    QueryResult { id, distance: dist }
359                })
360            })
361            .collect();
362
363        results.sort_by(|a, b| {
364            a.distance
365                .partial_cmp(&b.distance)
366                .unwrap_or(std::cmp::Ordering::Equal)
367        });
368        results.truncate(k);
369
370        if let Some(ref m) = self.metrics {
371            if let Some(t) = timer {
372                m.record_query(candidates.len() as u64, t.elapsed_ns());
373            }
374        }
375
376        Ok(results)
377    }
378
379    // ------------------------------------------------------------------
380    // Removal / lookup
381    // ------------------------------------------------------------------
382
383    /// Remove a vector by ID.
384    pub fn remove(&self, id: usize) -> Result<()> {
385        let mut inner = self.inner.write();
386
387        let vec = inner.vectors.remove(&id).ok_or(LshError::NotFound(id))?;
388
389        let hashes: Vec<u64> = inner
390            .hashers
391            .iter()
392            .map(|h| h.hash_vector_fast(&vec.view()))
393            .collect();
394        for (i, hash) in hashes.into_iter().enumerate() {
395            if let Some(bucket) = inner.tables[i].get_mut(&hash) {
396                bucket.retain(|&x| x != id);
397                if bucket.is_empty() {
398                    inner.tables[i].remove(&hash);
399                }
400            }
401        }
402
403        Ok(())
404    }
405
406    /// Check whether a vector ID is present.
407    pub fn contains(&self, id: usize) -> bool {
408        self.inner.read().vectors.contains_key(&id)
409    }
410
411    // ------------------------------------------------------------------
412    // Stats / metrics
413    // ------------------------------------------------------------------
414
415    /// Number of stored vectors.
416    pub fn len(&self) -> usize {
417        self.inner.read().vectors.len()
418    }
419
420    /// True when the index holds no vectors.
421    pub fn is_empty(&self) -> bool {
422        self.inner.read().vectors.is_empty()
423    }
424
425    /// Compute aggregate statistics about the index.
426    pub fn stats(&self) -> IndexStats {
427        let inner = self.inner.read();
428
429        let total_buckets: usize = inner.tables.iter().map(|t| t.len()).sum();
430        let total_entries: usize = inner
431            .tables
432            .iter()
433            .flat_map(|t| t.values())
434            .map(|v| v.len())
435            .sum();
436        let max_bucket_size = inner
437            .tables
438            .iter()
439            .flat_map(|t| t.values())
440            .map(|v| v.len())
441            .max()
442            .unwrap_or(0);
443
444        let avg_bucket_size = if total_buckets > 0 {
445            total_entries as f64 / total_buckets as f64
446        } else {
447            0.0
448        };
449
450        let vector_mem =
451            inner.vectors.len() * (inner.config.dim * 4 + std::mem::size_of::<usize>());
452        let table_mem = total_buckets * (std::mem::size_of::<u64>() + 24);
453        let entry_mem = total_entries * std::mem::size_of::<usize>();
454        let proj_mem =
455            inner.config.num_tables * inner.config.num_hashes * inner.config.dim * 4;
456
457        IndexStats {
458            num_vectors: inner.vectors.len(),
459            num_tables: inner.config.num_tables,
460            num_hashes: inner.config.num_hashes,
461            dimension: inner.config.dim,
462            total_buckets,
463            avg_bucket_size,
464            max_bucket_size,
465            memory_estimate_bytes: vector_mem + table_mem + entry_mem + proj_mem,
466        }
467    }
468
469    /// Snapshot of runtime metrics (`None` if metrics were not enabled).
470    pub fn metrics(&self) -> Option<MetricsSnapshot> {
471        self.metrics.as_ref().map(|m| m.snapshot())
472    }
473
474    /// Reset metrics counters.
475    pub fn reset_metrics(&self) {
476        if let Some(ref m) = self.metrics {
477            m.reset();
478        }
479    }
480
481    /// Remove all vectors from the index (projections are preserved).
482    pub fn clear(&self) {
483        let mut inner = self.inner.write();
484        inner.vectors.clear();
485        for table in &mut inner.tables {
486            table.clear();
487        }
488        inner.next_id = 0;
489    }
490
491    /// Return a clone of the current configuration.
492    pub fn config(&self) -> IndexConfig {
493        self.inner.read().config.clone()
494    }
495}
496
497// ---------------------------------------------------------------------------
498// Parallel batch ops (behind `parallel` feature)
499// ---------------------------------------------------------------------------
500
#[cfg(feature = "parallel")]
impl LshIndex {
    /// Insert many vectors using rayon for parallel hash computation.
    ///
    /// Every vector's length is checked before anything is inserted, so a dimension
    /// error leaves the index unchanged. Normalization and hashing run without the
    /// lock. The results are then applied under one write lock, with the same
    /// replace-on-duplicate-ID semantics and per-vector insert metric as
    /// [`insert`](Self::insert).
    ///
    /// # Errors
    ///
    /// [`LshError::DimensionMismatch`] for the first vector of the wrong length.
    pub fn par_insert_batch(&self, vectors: &[(usize, Vec<f32>)]) -> Result<()> {
        use rayon::prelude::*;

        // Snapshot config + hashers so the lock is not held during parallel work.
        // Nothing in this module mutates the hashers after construction.
        let (config, hashers) = {
            let inner = self.inner.read();
            (inner.config.clone(), inner.hashers.clone())
        };

        if let Some((_, v)) = vectors.iter().find(|(_, v)| v.len() != config.dim) {
            return Err(LshError::DimensionMismatch {
                expected: config.dim,
                got: v.len(),
            });
        }

        // Parallel: normalize + hash.
        let prepared: Vec<(usize, Array1<f32>, Vec<u64>)> = vectors
            .par_iter()
            .map(|(id, v)| {
                let mut arr = Array1::from_vec(v.clone());
                if config.normalize_vectors {
                    distance::normalize(&mut arr);
                }
                let hashes: Vec<u64> = hashers
                    .iter()
                    .map(|h| h.hash_vector_fast(&arr.view()))
                    .collect();
                (*id, arr, hashes)
            })
            .collect();

        // Sequential: apply to shared state. Reborrow the guard so the tables and
        // vectors can be borrowed independently.
        let mut guard = self.inner.write();
        let inner = &mut *guard;
        for (id, arr, hashes) in prepared {
            // Replace semantics (as in `insert`): take the previous vector out and
            // unlink its bucket entries.
            if let Some(old) = inner.vectors.remove(&id) {
                for (hasher, table) in hashers.iter().zip(inner.tables.iter_mut()) {
                    let old_hash = hasher.hash_vector_fast(&old.view());
                    if let Some(bucket) = table.get_mut(&old_hash) {
                        bucket.retain(|&x| x != id);
                        if bucket.is_empty() {
                            table.remove(&old_hash);
                        }
                    }
                }
            }

            for (table, hash) in inner.tables.iter_mut().zip(hashes) {
                table.entry(hash).or_default().push(id);
            }
            inner.vectors.insert(id, arr);
            // `usize::MAX` is a legal ID: `id + 1` would overflow.
            inner.next_id = inner.next_id.max(id.saturating_add(1));

            if let Some(m) = &self.metrics {
                m.record_insert();
            }
        }

        Ok(())
    }

    /// Query multiple vectors in parallel via [`query`](Self::query).
    ///
    /// Results are returned in the order of `queries`.
    ///
    /// # Errors
    ///
    /// Returns an error if any individual query fails (e.g. a dimension mismatch).
    pub fn par_query_batch(
        &self,
        queries: &[Vec<f32>],
        k: usize,
    ) -> Result<Vec<Vec<QueryResult>>> {
        use rayon::prelude::*;

        queries.par_iter().map(|q| self.query(q, k)).collect()
    }
}
585
586// ---------------------------------------------------------------------------
587// Builder
588// ---------------------------------------------------------------------------
589
590/// Fluent builder for [`LshIndex`].
591#[derive(Default)]
592pub struct LshIndexBuilder {
593    config: IndexConfig,
594    enable_metrics: bool,
595}
596
597impl LshIndexBuilder {
598    pub fn new() -> Self {
599        Self::default()
600    }
601
602    pub fn dim(mut self, dim: usize) -> Self {
603        self.config.dim = dim;
604        self
605    }
606
607    pub fn num_hashes(mut self, n: usize) -> Self {
608        self.config.num_hashes = n;
609        self
610    }
611
612    pub fn num_tables(mut self, n: usize) -> Self {
613        self.config.num_tables = n;
614        self
615    }
616
617    pub fn num_probes(mut self, n: usize) -> Self {
618        self.config.num_probes = n;
619        self
620    }
621
622    pub fn distance_metric(mut self, m: DistanceMetric) -> Self {
623        self.config.distance_metric = m;
624        self
625    }
626
627    pub fn normalize(mut self, yes: bool) -> Self {
628        self.config.normalize_vectors = yes;
629        self
630    }
631
632    pub fn seed(mut self, seed: u64) -> Self {
633        self.config.seed = Some(seed);
634        self
635    }
636
637    pub fn enable_metrics(mut self) -> Self {
638        self.enable_metrics = true;
639        self
640    }
641
642    /// Build the index, returning an error on invalid configuration.
643    pub fn build(self) -> Result<LshIndex> {
644        LshIndex::new_with_metrics(self.config, self.enable_metrics)
645    }
646}