Skip to main content

sbom_tools/matching/
lsh.rs

//! Locality-Sensitive Hashing (LSH) for approximate nearest neighbor search.
//!
//! This module provides `MinHash` LSH for efficient similarity search on large SBOMs
//! (10,000+ components). It trades some accuracy for dramatic speed improvements
//! by using hash-based approximate matching.
//!
//! # How it works
//!
//! 1. Each component name is converted to a set of character shingles (n-grams)
//! 2. `MinHash` signatures are computed for each shingle set
//! 3. Signatures are divided into bands and hashed into buckets
//! 4. Components in the same bucket are candidate matches
//!
//! # Performance
//!
//! - Build time: O(n × s × k), where s = shingles per name and k = signature size
//! - Query time: O(s × k) to compute the query signature, one bucket lookup per
//!   band, plus O(m) for the m candidates returned
//! - Space: O(n × k) for signatures

20use super::index::ComponentIndex;
21use crate::model::{CanonicalId, Component, NormalizedSbom};
22use std::collections::{HashMap, HashSet};
23use std::hash::{Hash, Hasher};
24
/// Tuning parameters for an [`LshIndex`].
///
/// Ready-made presets are available through [`LshConfig::for_threshold`],
/// [`LshConfig::default_balanced`], [`LshConfig::strict`] and
/// [`LshConfig::permissive`].
#[derive(Debug, Clone)]
pub struct LshConfig {
    /// How many `MinHash` functions make up one signature.
    pub num_hashes: usize,
    /// How many bands the signature is split into for bucketing.
    pub num_bands: usize,
    /// Length, in characters, of each name shingle (n-gram).
    pub shingle_size: usize,
    /// The similarity threshold this configuration was created for.
    pub target_threshold: f64,
    /// Add the component's ecosystem as an extra shingle token, nudging
    /// components of the same ecosystem into shared buckets.
    pub include_ecosystem_token: bool,
    /// Add the component's group/namespace (e.g. a Maven group ID or an npm
    /// scope) as an extra shingle token.
    pub include_group_token: bool,
}
41
42impl LshConfig {
43    /// Create a config tuned for the given similarity threshold.
44    ///
45    /// The number of bands and rows are chosen to maximize the probability
46    /// of finding pairs with similarity >= threshold while minimizing false positives.
47    #[must_use]
48    pub fn for_threshold(threshold: f64) -> Self {
49        // For threshold t, optimal parameters satisfy: t ≈ (1/b)^(1/r)
50        // where b = bands, r = rows per band, and b × r = num_hashes
51        //
52        // Common configurations:
53        // - t=0.5: b=20, r=5 (100 hashes)
54        // - t=0.8: b=50, r=2 (100 hashes)
55        // - t=0.9: b=90, r=1 (90 hashes) - but this is just exact bucketing
56
57        let (num_bands, rows_per_band) = if threshold >= 0.9 {
58            (50, 2) // 100 hashes, catches ~90%+ similar
59        } else if threshold >= 0.8 {
60            (25, 4) // 100 hashes, catches ~80%+ similar
61        } else if threshold >= 0.7 {
62            (20, 5) // 100 hashes, catches ~70%+ similar
63        } else if threshold >= 0.5 {
64            (10, 10) // 100 hashes, catches ~50%+ similar
65        } else {
66            (5, 20) // 100 hashes, very permissive
67        };
68
69        Self {
70            num_hashes: num_bands * rows_per_band,
71            num_bands,
72            shingle_size: 3, // Trigrams work well for package names
73            target_threshold: threshold,
74            include_ecosystem_token: true, // Helps group by ecosystem
75            include_group_token: false,    // Optional, disabled by default
76        }
77    }
78
79    /// Default config for balanced matching (~0.8 threshold).
80    #[must_use]
81    pub fn default_balanced() -> Self {
82        Self::for_threshold(0.8)
83    }
84
85    /// Config for strict matching (~0.9 threshold).
86    #[must_use]
87    pub fn strict() -> Self {
88        Self::for_threshold(0.9)
89    }
90
91    /// Config for permissive matching (~0.5 threshold).
92    #[must_use]
93    pub fn permissive() -> Self {
94        Self::for_threshold(0.5)
95    }
96
97    /// Get rows per band (signature elements per band).
98    #[must_use]
99    pub const fn rows_per_band(&self) -> usize {
100        self.num_hashes / self.num_bands
101    }
102}
103
104impl Default for LshConfig {
105    fn default() -> Self {
106        Self::default_balanced()
107    }
108}
109
/// A component's `MinHash` signature: the minimum hash value observed for
/// each hash function over the component's shingle set.
///
/// Signatures are only comparable when computed by the same [`LshIndex`],
/// because each index draws its own random hash coefficients.
#[derive(Debug, Clone)]
pub struct MinHashSignature {
    /// One minimum hash value per hash function, in hash-function order.
    pub values: Vec<u64>,
}
116
117impl MinHashSignature {
118    /// Compute the estimated Jaccard similarity between two signatures.
119    #[must_use]
120    pub fn estimated_similarity(&self, other: &Self) -> f64 {
121        if self.values.len() != other.values.len() {
122            return 0.0;
123        }
124
125        let matching = self
126            .values
127            .iter()
128            .zip(other.values.iter())
129            .filter(|(a, b)| a == b)
130            .count();
131
132        matching as f64 / self.values.len() as f64
133    }
134}
135
136/// LSH index for efficient approximate nearest neighbor search.
137pub struct LshIndex {
138    /// Configuration
139    config: LshConfig,
140    /// `MinHash` signatures for each component
141    signatures: HashMap<CanonicalId, MinHashSignature>,
142    /// Band buckets: `band_index` -> `bucket_hash` -> component IDs
143    buckets: Vec<HashMap<u64, Vec<CanonicalId>>>,
144    /// Hash coefficients for `MinHash` (a, b pairs for h(x) = (ax + b) mod p)
145    hash_coeffs: Vec<(u64, u64)>,
146    /// Large prime for hashing
147    prime: u64,
148}
149
150impl LshIndex {
151    /// Create a new LSH index with the given configuration.
152    #[must_use]
153    pub fn new(config: LshConfig) -> Self {
154        use std::collections::hash_map::RandomState;
155        use std::hash::BuildHasher;
156
157        // Generate random hash coefficients
158        let mut hash_coeffs = Vec::with_capacity(config.num_hashes);
159        let random_state = RandomState::new();
160
161        for i in 0..config.num_hashes {
162            let a = random_state.hash_one(i as u64 * 31337) | 1; // Ensure odd (coprime with 2^64)
163
164            let b = random_state.hash_one(i as u64 * 7919 + 12345);
165
166            hash_coeffs.push((a, b));
167        }
168
169        // Initialize empty buckets for each band
170        let buckets = (0..config.num_bands)
171            .map(|_| HashMap::with_capacity(64))
172            .collect();
173
174        Self {
175            config,
176            signatures: HashMap::with_capacity(256),
177            buckets,
178            hash_coeffs,
179            prime: 0xFFFF_FFFF_FFFF_FFC5, // Large prime close to 2^64
180        }
181    }
182
183    /// Build an LSH index from an SBOM.
184    #[must_use]
185    pub fn build(sbom: &NormalizedSbom, config: LshConfig) -> Self {
186        let mut index = Self::new(config);
187
188        for (id, comp) in &sbom.components {
189            index.insert(id.clone(), comp);
190        }
191
192        index
193    }
194
195    /// Insert a component into the index.
196    pub fn insert(&mut self, id: CanonicalId, component: &Component) {
197        // Compute shingles from the component (uses ecosystem-aware normalization)
198        let shingles = self.compute_shingles(component);
199
200        // Compute MinHash signature
201        let signature = self.compute_minhash(&shingles);
202
203        // Insert into band buckets
204        self.insert_into_buckets(&id, &signature);
205
206        // Store signature
207        self.signatures.insert(id, signature);
208    }
209
210    /// Find candidate matches for a component.
211    ///
212    /// Returns component IDs that are likely similar based on LSH buckets.
213    /// These candidates should be verified with exact similarity computation.
214    #[must_use]
215    pub fn find_candidates(&self, component: &Component) -> Vec<CanonicalId> {
216        let shingles = self.compute_shingles(component);
217        let signature = self.compute_minhash(&shingles);
218
219        self.find_candidates_by_signature(&signature)
220    }
221
222    /// Find candidates using a pre-computed signature.
223    #[must_use]
224    pub fn find_candidates_by_signature(&self, signature: &MinHashSignature) -> Vec<CanonicalId> {
225        let mut candidates = HashSet::new();
226        let rows_per_band = self.config.rows_per_band();
227
228        for (band_idx, bucket_map) in self.buckets.iter().enumerate() {
229            let band_hash = self.hash_band(signature, band_idx, rows_per_band);
230
231            if let Some(ids) = bucket_map.get(&band_hash) {
232                for id in ids {
233                    candidates.insert(id.clone());
234                }
235            }
236        }
237
238        candidates.into_iter().collect()
239    }
240
241    /// Find candidates for a component from another index.
242    ///
243    /// Useful for diffing: build index from new SBOM, query with old SBOM components.
244    pub fn find_candidates_for_id(&self, id: &CanonicalId) -> Vec<CanonicalId> {
245        self.signatures.get(id).map_or_else(Vec::new, |signature| {
246            self.find_candidates_by_signature(signature)
247        })
248    }
249
250    /// Get the `MinHash` signature for a component.
251    #[must_use]
252    pub fn get_signature(&self, id: &CanonicalId) -> Option<&MinHashSignature> {
253        self.signatures.get(id)
254    }
255
256    /// Estimate similarity between two components in the index.
257    #[must_use]
258    pub fn estimate_similarity(&self, id_a: &CanonicalId, id_b: &CanonicalId) -> Option<f64> {
259        let sig_a = self.signatures.get(id_a)?;
260        let sig_b = self.signatures.get(id_b)?;
261        Some(sig_a.estimated_similarity(sig_b))
262    }
263
264    /// Get statistics about the index.
265    pub fn stats(&self) -> LshIndexStats {
266        let total_components = self.signatures.len();
267        let total_buckets: usize = self
268            .buckets
269            .iter()
270            .map(std::collections::HashMap::len)
271            .sum();
272        let max_bucket_size = self
273            .buckets
274            .iter()
275            .flat_map(|b| b.values())
276            .map(std::vec::Vec::len)
277            .max()
278            .unwrap_or(0);
279        let avg_bucket_size = if total_buckets > 0 {
280            self.buckets
281                .iter()
282                .flat_map(|b| b.values())
283                .map(std::vec::Vec::len)
284                .sum::<usize>() as f64
285                / total_buckets as f64
286        } else {
287            0.0
288        };
289
290        LshIndexStats {
291            total_components,
292            num_bands: self.config.num_bands,
293            num_hashes: self.config.num_hashes,
294            total_buckets,
295            max_bucket_size,
296            avg_bucket_size,
297        }
298    }
299
300    /// Compute character shingles (n-grams) from a component.
301    ///
302    /// Uses ecosystem-aware normalization from `ComponentIndex` for consistent
303    /// shingling across `PyPI`, Cargo, npm, etc. Also adds optional ecosystem
304    /// and group tokens to improve candidate grouping.
305    fn compute_shingles(&self, component: &Component) -> HashSet<u64> {
306        // Get ecosystem for normalization
307        let ecosystem = component
308            .ecosystem
309            .as_ref()
310            .map(std::string::ToString::to_string);
311        let ecosystem_str = ecosystem.as_deref();
312
313        // Use ComponentIndex's normalization for consistency
314        let normalized = ComponentIndex::normalize_name(&component.name, ecosystem_str);
315        let chars: Vec<char> = normalized.chars().collect();
316
317        // Estimate capacity: roughly (len - shingle_size + 1) shingles + 2 tokens
318        let estimated_shingles = chars.len().saturating_sub(self.config.shingle_size) + 3;
319        let mut shingles = HashSet::with_capacity(estimated_shingles);
320
321        // Compute name shingles
322        if chars.len() < self.config.shingle_size {
323            // For very short names, use the whole name as a shingle
324            let mut hasher = std::collections::hash_map::DefaultHasher::new();
325            normalized.hash(&mut hasher);
326            shingles.insert(hasher.finish());
327        } else {
328            // Hash character windows directly without allocating intermediate strings
329            for window in chars.windows(self.config.shingle_size) {
330                let mut hasher = std::collections::hash_map::DefaultHasher::new();
331                window.hash(&mut hasher);
332                shingles.insert(hasher.finish());
333            }
334        }
335
336        // Add ecosystem token (helps group components by ecosystem)
337        if self.config.include_ecosystem_token
338            && let Some(ref eco) = ecosystem
339        {
340            let mut hasher = std::collections::hash_map::DefaultHasher::new();
341            "__eco:".hash(&mut hasher);
342            eco.to_lowercase().hash(&mut hasher);
343            shingles.insert(hasher.finish());
344        }
345
346        // Add group/namespace token (useful for Maven group IDs, npm scopes)
347        if self.config.include_group_token
348            && let Some(ref group) = component.group
349        {
350            let mut hasher = std::collections::hash_map::DefaultHasher::new();
351            "__grp:".hash(&mut hasher);
352            group.to_lowercase().hash(&mut hasher);
353            shingles.insert(hasher.finish());
354        }
355
356        shingles
357    }
358
359    /// Compute `MinHash` signature from shingles.
360    fn compute_minhash(&self, shingles: &HashSet<u64>) -> MinHashSignature {
361        let mut min_hashes = vec![u64::MAX; self.config.num_hashes];
362
363        for &shingle in shingles {
364            for (i, &(a, b)) in self.hash_coeffs.iter().enumerate() {
365                // h_i(x) = (a*x + b) mod prime
366                let hash = a.wrapping_mul(shingle).wrapping_add(b) % self.prime;
367                if hash < min_hashes[i] {
368                    min_hashes[i] = hash;
369                }
370            }
371        }
372
373        MinHashSignature { values: min_hashes }
374    }
375
376    /// Insert a signature into band buckets.
377    fn insert_into_buckets(&mut self, id: &CanonicalId, signature: &MinHashSignature) {
378        let rows_per_band = self.config.rows_per_band();
379
380        // Pre-compute all band hashes to avoid borrow conflicts
381        let band_hashes: Vec<u64> = (0..self.config.num_bands)
382            .map(|band_idx| self.hash_band(signature, band_idx, rows_per_band))
383            .collect();
384
385        for (band_idx, bucket_map) in self.buckets.iter_mut().enumerate() {
386            bucket_map
387                .entry(band_hashes[band_idx])
388                .or_default()
389                .push(id.clone());
390        }
391    }
392
393    /// Hash a band of the signature.
394    fn hash_band(
395        &self,
396        signature: &MinHashSignature,
397        band_idx: usize,
398        rows_per_band: usize,
399    ) -> u64 {
400        let start = band_idx * rows_per_band;
401        let end = (start + rows_per_band).min(signature.values.len());
402
403        let mut hasher = std::collections::hash_map::DefaultHasher::new();
404        for &value in &signature.values[start..end] {
405            value.hash(&mut hasher);
406        }
407        hasher.finish()
408    }
409}
410
/// Summary statistics of an [`LshIndex`], as produced by `LshIndex::stats`.
#[derive(Debug, Clone)]
pub struct LshIndexStats {
    /// Number of components stored in the index.
    pub total_components: usize,
    /// Number of bands the signatures are split into.
    pub num_bands: usize,
    /// Number of hash functions per signature.
    pub num_hashes: usize,
    /// Number of occupied buckets, summed over all bands.
    pub total_buckets: usize,
    /// Largest number of component IDs held by any single bucket.
    pub max_bucket_size: usize,
    /// Mean number of component IDs per occupied bucket (0.0 when there are none).
    pub avg_bucket_size: f64,
}
427
428impl std::fmt::Display for LshIndexStats {
429    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
430        write!(
431            f,
432            "LSH Index: {} components, {} bands × {} hashes, {} buckets (max: {}, avg: {:.1})",
433            self.total_components,
434            self.num_bands,
435            self.num_hashes / self.num_bands,
436            self.total_buckets,
437            self.max_bucket_size,
438            self.avg_bucket_size
439        )
440    }
441}
442
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::DocumentMetadata;

    /// Build a bare component whose name doubles as its format-specific ID.
    fn make_component(name: &str) -> Component {
        Component::new(name.to_string(), format!("id-{name}"))
    }

    #[test]
    fn test_lsh_config_for_threshold() {
        let config = LshConfig::for_threshold(0.8);
        assert_eq!(config.num_hashes, 100);
        assert!(config.num_bands > 0);
        assert_eq!(config.num_hashes, config.num_bands * config.rows_per_band());
    }

    #[test]
    fn test_minhash_signature_similarity() {
        let sig_a = MinHashSignature {
            values: vec![1, 2, 3, 4, 5],
        };
        let sig_b = MinHashSignature {
            values: vec![1, 2, 3, 4, 5],
        };
        assert_eq!(sig_a.estimated_similarity(&sig_b), 1.0);

        let sig_c = MinHashSignature {
            values: vec![1, 2, 3, 6, 7],
        };
        assert!((sig_a.estimated_similarity(&sig_c) - 0.6).abs() < 0.01);
    }

    #[test]
    fn test_lsh_index_build() {
        let mut sbom = NormalizedSbom::new(DocumentMetadata::default());
        sbom.add_component(make_component("lodash"));
        sbom.add_component(make_component("lodash-es"));
        sbom.add_component(make_component("underscore"));
        sbom.add_component(make_component("react"));

        let index = LshIndex::build(&sbom, LshConfig::default_balanced());
        let stats = index.stats();

        assert_eq!(stats.total_components, 4);
        assert!(stats.total_buckets > 0);
    }

    #[test]
    fn test_lsh_finds_similar_names() {
        let mut sbom = NormalizedSbom::new(DocumentMetadata::default());
        let lodash = make_component("lodash");
        let lodash_id = lodash.canonical_id.clone();
        sbom.add_component(lodash);
        sbom.add_component(make_component("lodash-es"));
        sbom.add_component(make_component("lodash-fp"));
        sbom.add_component(make_component("react"));
        sbom.add_component(make_component("angular"));

        let index = LshIndex::build(&sbom, LshConfig::for_threshold(0.5));

        // A query identical to an indexed component produces the same shingles
        // and therefore the same signature, so it lands in every one of that
        // component's buckets regardless of the random hash coefficients.
        let candidates = index.find_candidates(&make_component("lodash"));
        assert!(
            candidates.contains(&lodash_id),
            "an exact-name query must return the indexed component itself"
        );
    }

    #[test]
    fn test_lsh_signature_estimation() {
        let mut sbom = NormalizedSbom::new(DocumentMetadata::default());

        let comp1 = make_component("lodash");
        let comp2 = make_component("lodash-es");
        let comp3 = make_component("completely-different-name");

        let id1 = comp1.canonical_id.clone();
        let id2 = comp2.canonical_id.clone();
        let id3 = comp3.canonical_id.clone();

        sbom.add_component(comp1);
        sbom.add_component(comp2);
        sbom.add_component(comp3);

        let index = LshIndex::build(&sbom, LshConfig::default_balanced());

        // Similar names should have higher estimated similarity
        let sim_12 = index.estimate_similarity(&id1, &id2).unwrap();
        let sim_13 = index.estimate_similarity(&id1, &id3).unwrap();

        assert!(
            sim_12 > sim_13,
            "lodash vs lodash-es ({sim_12:.2}) should be more similar than lodash vs completely-different ({sim_13:.2})"
        );
    }

    #[test]
    fn test_lsh_lookup_by_id() {
        let mut sbom = NormalizedSbom::new(DocumentMetadata::default());
        let comp = make_component("serde");
        let id = comp.canonical_id.clone();
        sbom.add_component(comp);

        let index = LshIndex::build(&sbom, LshConfig::default_balanced());

        // An indexed component is identical to itself and is its own candidate.
        assert_eq!(index.estimate_similarity(&id, &id), Some(1.0));
        assert!(index.find_candidates_for_id(&id).contains(&id));

        // Unknown IDs are reported as absent rather than panicking.
        let missing_component = make_component("not-indexed");
        let missing = missing_component.canonical_id.clone();
        assert!(index.get_signature(&missing).is_none());
        assert!(index.find_candidates_for_id(&missing).is_empty());
        assert_eq!(index.estimate_similarity(&id, &missing), None);
    }

    #[test]
    fn test_lsh_index_stats() {
        let config = LshConfig::for_threshold(0.8);
        let index = LshIndex::new(config);

        let stats = index.stats();
        assert_eq!(stats.total_components, 0);
        assert_eq!(stats.num_bands, 25);
        assert_eq!(stats.num_hashes, 100);
        assert_eq!(stats.total_buckets, 0);
        assert_eq!(stats.max_bucket_size, 0);
    }
}