Skip to main content

sbom_tools/matching/
lsh.rs

1//! Locality-Sensitive Hashing (LSH) for approximate nearest neighbor search.
2//!
3//! This module provides MinHash LSH for efficient similarity search on large SBOMs
4//! (10,000+ components). It trades some accuracy for dramatic speed improvements
5//! by using hash-based approximate matching.
6//!
7//! # How it works
8//!
9//! 1. Each component name is converted to a set of character shingles (n-grams)
10//! 2. MinHash signatures are computed for each shingle set
11//! 3. Signatures are divided into bands and hashed into buckets
12//! 4. Components in the same bucket are candidate matches
13//!
14//! # Performance
15//!
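//! - Build time: O(n × s × k) where s = shingles per name and k = signature size
//! - Query time: O(s × k) to compute the signature, one bucket lookup per band, plus O(m) to collect the m candidates
//! - Space: O(n × k) for signatures, plus one bucket entry per component per band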
19
20use super::index::ComponentIndex;
21use crate::model::{CanonicalId, Component, NormalizedSbom};
22use std::collections::{HashMap, HashSet};
23use std::hash::{Hash, Hasher};
24
/// Configuration for LSH index.
///
/// All fields are public, so a value can be assembled by hand as well as
/// through `LshConfig::for_threshold`. `num_hashes` is expected to be a
/// multiple of `num_bands`: rows-per-band is computed with integer division,
/// so any remainder hashes never take part in banding.
#[derive(Debug, Clone)]
pub struct LshConfig {
    /// Number of hash functions in the MinHash signature
    pub num_hashes: usize,
    /// Number of bands to divide the signature into
    pub num_bands: usize,
    /// Size of character shingles (n-grams)
    pub shingle_size: usize,
    /// Minimum Jaccard similarity threshold this config is tuned for
    /// (recorded by `for_threshold`; not read anywhere else in this module)
    pub target_threshold: f64,
    /// Include ecosystem as a token in shingles (improves grouping by ecosystem)
    pub include_ecosystem_token: bool,
    /// Include group/namespace as a token in shingles (useful for Maven, npm scopes)
    pub include_group_token: bool,
}
41
42impl LshConfig {
43    /// Create a config tuned for the given similarity threshold.
44    ///
45    /// The number of bands and rows are chosen to maximize the probability
46    /// of finding pairs with similarity >= threshold while minimizing false positives.
47    pub fn for_threshold(threshold: f64) -> Self {
48        // For threshold t, optimal parameters satisfy: t ≈ (1/b)^(1/r)
49        // where b = bands, r = rows per band, and b × r = num_hashes
50        //
51        // Common configurations:
52        // - t=0.5: b=20, r=5 (100 hashes)
53        // - t=0.8: b=50, r=2 (100 hashes)
54        // - t=0.9: b=90, r=1 (90 hashes) - but this is just exact bucketing
55
56        let (num_bands, rows_per_band) = if threshold >= 0.9 {
57            (50, 2) // 100 hashes, catches ~90%+ similar
58        } else if threshold >= 0.8 {
59            (25, 4) // 100 hashes, catches ~80%+ similar
60        } else if threshold >= 0.7 {
61            (20, 5) // 100 hashes, catches ~70%+ similar
62        } else if threshold >= 0.5 {
63            (10, 10) // 100 hashes, catches ~50%+ similar
64        } else {
65            (5, 20) // 100 hashes, very permissive
66        };
67
68        Self {
69            num_hashes: num_bands * rows_per_band,
70            num_bands,
71            shingle_size: 3, // Trigrams work well for package names
72            target_threshold: threshold,
73            include_ecosystem_token: true,  // Helps group by ecosystem
74            include_group_token: false,     // Optional, disabled by default
75        }
76    }
77
78    /// Default config for balanced matching (~0.8 threshold).
79    pub fn default_balanced() -> Self {
80        Self::for_threshold(0.8)
81    }
82
83    /// Config for strict matching (~0.9 threshold).
84    pub fn strict() -> Self {
85        Self::for_threshold(0.9)
86    }
87
88    /// Config for permissive matching (~0.5 threshold).
89    pub fn permissive() -> Self {
90        Self::for_threshold(0.5)
91    }
92
93    /// Get rows per band (signature elements per band).
94    pub fn rows_per_band(&self) -> usize {
95        self.num_hashes / self.num_bands
96    }
97}
98
99impl Default for LshConfig {
100    fn default() -> Self {
101        Self::default_balanced()
102    }
103}
104
/// MinHash signature for a component.
///
/// Two signatures are only comparable position by position, so they must
/// have the same length.
#[derive(Debug, Clone)]
pub struct MinHashSignature {
    /// The hash values (one per hash function)
    pub values: Vec<u64>,
}
111
112impl MinHashSignature {
113    /// Compute the estimated Jaccard similarity between two signatures.
114    pub fn estimated_similarity(&self, other: &MinHashSignature) -> f64 {
115        if self.values.len() != other.values.len() {
116            return 0.0;
117        }
118
119        let matching = self
120            .values
121            .iter()
122            .zip(other.values.iter())
123            .filter(|(a, b)| a == b)
124            .count();
125
126        matching as f64 / self.values.len() as f64
127    }
128}
129
/// LSH index for efficient approximate nearest neighbor search.
///
/// Holds one MinHash signature per inserted component plus, for every band,
/// a map from the hash of that band's slice of the signature to the IDs of
/// the components that produced it.
pub struct LshIndex {
    /// Configuration
    config: LshConfig,
    /// MinHash signatures for each component
    signatures: HashMap<CanonicalId, MinHashSignature>,
    /// Band buckets: band_index -> bucket_hash -> component IDs
    buckets: Vec<HashMap<u64, Vec<CanonicalId>>>,
    /// Hash coefficients for MinHash (a, b pairs for h(x) = (ax + b) mod p);
    /// one pair per hash function, where `ax + b` is evaluated with wrapping
    /// 64-bit arithmetic before the reduction
    hash_coeffs: Vec<(u64, u64)>,
    /// Large prime for hashing
    prime: u64,
}
143
144impl LshIndex {
145    /// Create a new LSH index with the given configuration.
146    pub fn new(config: LshConfig) -> Self {
147        use std::collections::hash_map::RandomState;
148        use std::hash::BuildHasher;
149
150        // Generate random hash coefficients
151        let mut hash_coeffs = Vec::with_capacity(config.num_hashes);
152        let random_state = RandomState::new();
153
154        for i in 0..config.num_hashes {
155            let a = random_state.hash_one(i as u64 * 31337) | 1; // Ensure odd (coprime with 2^64)
156
157            let b = random_state.hash_one(i as u64 * 7919 + 12345);
158
159            hash_coeffs.push((a, b));
160        }
161
162        // Initialize empty buckets for each band
163        let buckets = (0..config.num_bands)
164            .map(|_| HashMap::with_capacity(64))
165            .collect();
166
167        Self {
168            config,
169            signatures: HashMap::with_capacity(256),
170            buckets,
171            hash_coeffs,
172            prime: 0xFFFFFFFFFFFFFFC5, // Large prime close to 2^64
173        }
174    }
175
176    /// Build an LSH index from an SBOM.
177    pub fn build(sbom: &NormalizedSbom, config: LshConfig) -> Self {
178        let mut index = Self::new(config);
179
180        for (id, comp) in &sbom.components {
181            index.insert(id.clone(), comp);
182        }
183
184        index
185    }
186
187    /// Insert a component into the index.
188    pub fn insert(&mut self, id: CanonicalId, component: &Component) {
189        // Compute shingles from the component (uses ecosystem-aware normalization)
190        let shingles = self.compute_shingles(component);
191
192        // Compute MinHash signature
193        let signature = self.compute_minhash(&shingles);
194
195        // Insert into band buckets
196        self.insert_into_buckets(&id, &signature);
197
198        // Store signature
199        self.signatures.insert(id, signature);
200    }
201
202    /// Find candidate matches for a component.
203    ///
204    /// Returns component IDs that are likely similar based on LSH buckets.
205    /// These candidates should be verified with exact similarity computation.
206    pub fn find_candidates(&self, component: &Component) -> Vec<CanonicalId> {
207        let shingles = self.compute_shingles(component);
208        let signature = self.compute_minhash(&shingles);
209
210        self.find_candidates_by_signature(&signature)
211    }
212
213    /// Find candidates using a pre-computed signature.
214    pub fn find_candidates_by_signature(&self, signature: &MinHashSignature) -> Vec<CanonicalId> {
215        let mut candidates = HashSet::new();
216        let rows_per_band = self.config.rows_per_band();
217
218        for (band_idx, bucket_map) in self.buckets.iter().enumerate() {
219            let band_hash = self.hash_band(signature, band_idx, rows_per_band);
220
221            if let Some(ids) = bucket_map.get(&band_hash) {
222                for id in ids {
223                    candidates.insert(id.clone());
224                }
225            }
226        }
227
228        candidates.into_iter().collect()
229    }
230
231    /// Find candidates for a component from another index.
232    ///
233    /// Useful for diffing: build index from new SBOM, query with old SBOM components.
234    pub fn find_candidates_for_id(&self, id: &CanonicalId) -> Vec<CanonicalId> {
235        if let Some(signature) = self.signatures.get(id) {
236            self.find_candidates_by_signature(signature)
237        } else {
238            Vec::new()
239        }
240    }
241
242    /// Get the MinHash signature for a component.
243    pub fn get_signature(&self, id: &CanonicalId) -> Option<&MinHashSignature> {
244        self.signatures.get(id)
245    }
246
247    /// Estimate similarity between two components in the index.
248    pub fn estimate_similarity(&self, id_a: &CanonicalId, id_b: &CanonicalId) -> Option<f64> {
249        let sig_a = self.signatures.get(id_a)?;
250        let sig_b = self.signatures.get(id_b)?;
251        Some(sig_a.estimated_similarity(sig_b))
252    }
253
254    /// Get statistics about the index.
255    pub fn stats(&self) -> LshIndexStats {
256        let total_components = self.signatures.len();
257        let total_buckets: usize = self.buckets.iter().map(|b| b.len()).sum();
258        let max_bucket_size = self
259            .buckets
260            .iter()
261            .flat_map(|b| b.values())
262            .map(|v| v.len())
263            .max()
264            .unwrap_or(0);
265        let avg_bucket_size = if total_buckets > 0 {
266            self.buckets
267                .iter()
268                .flat_map(|b| b.values())
269                .map(|v| v.len())
270                .sum::<usize>() as f64
271                / total_buckets as f64
272        } else {
273            0.0
274        };
275
276        LshIndexStats {
277            total_components,
278            num_bands: self.config.num_bands,
279            num_hashes: self.config.num_hashes,
280            total_buckets,
281            max_bucket_size,
282            avg_bucket_size,
283        }
284    }
285
286    /// Compute character shingles (n-grams) from a component.
287    ///
288    /// Uses ecosystem-aware normalization from ComponentIndex for consistent
289    /// shingling across PyPI, Cargo, npm, etc. Also adds optional ecosystem
290    /// and group tokens to improve candidate grouping.
291    fn compute_shingles(&self, component: &Component) -> HashSet<u64> {
292        // Get ecosystem for normalization
293        let ecosystem = component.ecosystem.as_ref().map(|e| e.to_string());
294        let ecosystem_str = ecosystem.as_deref();
295
296        // Use ComponentIndex's normalization for consistency
297        let normalized = ComponentIndex::normalize_name(&component.name, ecosystem_str);
298        let chars: Vec<char> = normalized.chars().collect();
299
300        // Estimate capacity: roughly (len - shingle_size + 1) shingles + 2 tokens
301        let estimated_shingles = chars.len().saturating_sub(self.config.shingle_size) + 3;
302        let mut shingles = HashSet::with_capacity(estimated_shingles);
303
304        // Compute name shingles
305        if chars.len() < self.config.shingle_size {
306            // For very short names, use the whole name as a shingle
307            let mut hasher = std::collections::hash_map::DefaultHasher::new();
308            normalized.hash(&mut hasher);
309            shingles.insert(hasher.finish());
310        } else {
311            // Hash character windows directly without allocating intermediate strings
312            for window in chars.windows(self.config.shingle_size) {
313                let mut hasher = std::collections::hash_map::DefaultHasher::new();
314                window.hash(&mut hasher);
315                shingles.insert(hasher.finish());
316            }
317        }
318
319        // Add ecosystem token (helps group components by ecosystem)
320        if self.config.include_ecosystem_token {
321            if let Some(ref eco) = ecosystem {
322                let mut hasher = std::collections::hash_map::DefaultHasher::new();
323                "__eco:".hash(&mut hasher);
324                eco.to_lowercase().hash(&mut hasher);
325                shingles.insert(hasher.finish());
326            }
327        }
328
329        // Add group/namespace token (useful for Maven group IDs, npm scopes)
330        if self.config.include_group_token {
331            if let Some(ref group) = component.group {
332                let mut hasher = std::collections::hash_map::DefaultHasher::new();
333                "__grp:".hash(&mut hasher);
334                group.to_lowercase().hash(&mut hasher);
335                shingles.insert(hasher.finish());
336            }
337        }
338
339        shingles
340    }
341
342    /// Compute MinHash signature from shingles.
343    fn compute_minhash(&self, shingles: &HashSet<u64>) -> MinHashSignature {
344        let mut min_hashes = vec![u64::MAX; self.config.num_hashes];
345
346        for &shingle in shingles {
347            for (i, &(a, b)) in self.hash_coeffs.iter().enumerate() {
348                // h_i(x) = (a*x + b) mod prime
349                let hash = a.wrapping_mul(shingle).wrapping_add(b) % self.prime;
350                if hash < min_hashes[i] {
351                    min_hashes[i] = hash;
352                }
353            }
354        }
355
356        MinHashSignature { values: min_hashes }
357    }
358
359    /// Insert a signature into band buckets.
360    fn insert_into_buckets(&mut self, id: &CanonicalId, signature: &MinHashSignature) {
361        let rows_per_band = self.config.rows_per_band();
362
363        // Pre-compute all band hashes to avoid borrow conflicts
364        let band_hashes: Vec<u64> = (0..self.config.num_bands)
365            .map(|band_idx| self.hash_band(signature, band_idx, rows_per_band))
366            .collect();
367
368        for (band_idx, bucket_map) in self.buckets.iter_mut().enumerate() {
369            bucket_map
370                .entry(band_hashes[band_idx])
371                .or_default()
372                .push(id.clone());
373        }
374    }
375
376    /// Hash a band of the signature.
377    fn hash_band(
378        &self,
379        signature: &MinHashSignature,
380        band_idx: usize,
381        rows_per_band: usize,
382    ) -> u64 {
383        let start = band_idx * rows_per_band;
384        let end = (start + rows_per_band).min(signature.values.len());
385
386        let mut hasher = std::collections::hash_map::DefaultHasher::new();
387        for &value in &signature.values[start..end] {
388            value.hash(&mut hasher);
389        }
390        hasher.finish()
391    }
392}
393
/// Statistics about an LSH index.
///
/// Bucket figures are aggregated over all bands.
#[derive(Debug, Clone)]
pub struct LshIndexStats {
    /// Total number of indexed components
    pub total_components: usize,
    /// Number of bands
    pub num_bands: usize,
    /// Total number of hash functions
    pub num_hashes: usize,
    /// Total number of non-empty buckets
    pub total_buckets: usize,
    /// Maximum components in a single bucket
    pub max_bucket_size: usize,
    /// Average components per bucket (0.0 when there are no buckets)
    pub avg_bucket_size: f64,
}
410
411impl std::fmt::Display for LshIndexStats {
412    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
413        write!(
414            f,
415            "LSH Index: {} components, {} bands × {} hashes, {} buckets (max: {}, avg: {:.1})",
416            self.total_components,
417            self.num_bands,
418            self.num_hashes / self.num_bands,
419            self.total_buckets,
420            self.max_bucket_size,
421            self.avg_bucket_size
422        )
423    }
424}
425
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::DocumentMetadata;

    /// Create a component named `name`, passing `id-<name>` as the second
    /// constructor argument.
    fn make_component(name: &str) -> Component {
        Component::new(name.to_string(), format!("id-{}", name))
    }

    /// Create an SBOM with default metadata holding one component per name.
    fn sbom_of(names: &[&str]) -> NormalizedSbom {
        let mut sbom = NormalizedSbom::new(DocumentMetadata::default());
        for name in names {
            sbom.add_component(make_component(name));
        }
        sbom
    }

    /// The balanced preset uses 100 hashes that split evenly into bands.
    #[test]
    fn test_lsh_config_for_threshold() {
        let cfg = LshConfig::for_threshold(0.8);

        assert_eq!(cfg.num_hashes, 100);
        assert!(cfg.num_bands > 0);
        assert_eq!(cfg.num_hashes, cfg.num_bands * cfg.rows_per_band());
    }

    /// The estimate is the fraction of equal positions.
    #[test]
    fn test_minhash_signature_similarity() {
        let base = MinHashSignature {
            values: vec![1, 2, 3, 4, 5],
        };

        let identical = MinHashSignature {
            values: vec![1, 2, 3, 4, 5],
        };
        assert_eq!(base.estimated_similarity(&identical), 1.0);

        // Three of the five positions agree.
        let partial = MinHashSignature {
            values: vec![1, 2, 3, 6, 7],
        };
        assert!((base.estimated_similarity(&partial) - 0.6).abs() < 0.01);
    }

    /// Building from an SBOM indexes every component and fills buckets.
    #[test]
    fn test_lsh_index_build() {
        let sbom = sbom_of(&["lodash", "lodash-es", "underscore", "react"]);

        let stats = LshIndex::build(&sbom, LshConfig::default_balanced()).stats();

        assert_eq!(stats.total_components, 4);
        assert!(stats.total_buckets > 0);
    }

    /// Querying with a name that is in the index yields candidates.
    #[test]
    fn test_lsh_finds_similar_names() {
        let sbom = sbom_of(&["lodash", "lodash-es", "lodash-fp", "react", "angular"]);
        let index = LshIndex::build(&sbom, LshConfig::for_threshold(0.5));

        let candidates = index.find_candidates(&make_component("lodash"));

        // LSH is probabilistic, so only the likely outcome is asserted.
        assert!(
            !candidates.is_empty(),
            "Should find at least some candidates"
        );
    }

    /// A near-identical name scores higher than an unrelated one.
    #[test]
    fn test_lsh_signature_estimation() {
        let base = make_component("lodash");
        let variant = make_component("lodash-es");
        let unrelated = make_component("completely-different-name");

        // Capture the IDs before the components move into the SBOM.
        let base_id = base.canonical_id.clone();
        let variant_id = variant.canonical_id.clone();
        let unrelated_id = unrelated.canonical_id.clone();

        let mut sbom = NormalizedSbom::new(DocumentMetadata::default());
        sbom.add_component(base);
        sbom.add_component(variant);
        sbom.add_component(unrelated);

        let index = LshIndex::build(&sbom, LshConfig::default_balanced());

        let sim_12 = index.estimate_similarity(&base_id, &variant_id).unwrap();
        let sim_13 = index.estimate_similarity(&base_id, &unrelated_id).unwrap();

        assert!(
            sim_12 > sim_13,
            "lodash vs lodash-es ({:.2}) should be more similar than lodash vs completely-different ({:.2})",
            sim_12, sim_13
        );
    }

    /// An empty index reports its configuration and zero components.
    #[test]
    fn test_lsh_index_stats() {
        let stats = LshIndex::new(LshConfig::for_threshold(0.8)).stats();

        assert_eq!(stats.total_components, 0);
        assert_eq!(stats.num_bands, 25);
        assert_eq!(stats.num_hashes, 100);
    }
}