Skip to main content

superbit/
tuning.rs

1use crate::distance::DistanceMetric;
2
/// Suggested parameters for the LSH index, produced by auto-tuning.
///
/// Returned by [`suggest_params`].
#[derive(Debug, Clone, PartialEq)]
pub struct SuggestedParams {
    /// Hash bits per table (`K` in the collision model).
    pub num_hashes: usize,
    /// Number of hash tables (`L` in the collision model).
    pub num_tables: usize,
    /// Multi-probe count to use at query time.
    pub num_probes: usize,
    /// Recall predicted by the tuning model for these parameters.
    pub estimated_recall: f64,
}
11
12/// Suggest LSH parameters based on dataset characteristics and desired recall.
13///
14/// Uses the theoretical collision probability for random hyperplane LSH:
15///   P(collision per bit) = 1 - theta / pi
16///
17/// For K hash bits per table: P_table = P^K
18/// For L tables: P_total = 1 - (1 - P_table)^L
19///
20/// # Arguments
21/// * `target_recall` - Desired recall in [0.5, 0.999]
22/// * `dataset_size` - Expected number of vectors
23/// * `_dim` - Vector dimensionality (reserved for future heuristics)
24/// * `metric` - Distance metric being used
25pub fn suggest_params(
26    target_recall: f64,
27    dataset_size: usize,
28    _dim: usize,
29    metric: DistanceMetric,
30) -> SuggestedParams {
31    let target_recall = target_recall.clamp(0.5, 0.999);
32
33    // Assume average angle of ~60 degrees between relevant (nearby) pairs.
34    // P(sign match) = 1 - 60/180 = 0.667 for cosine/angular LSH.
35    let p_collision: f64 = match metric {
36        DistanceMetric::Cosine | DistanceMetric::DotProduct => 0.667,
37        DistanceMetric::Euclidean => 0.6,
38    };
39
40    let mut best = SuggestedParams {
41        num_hashes: 16,
42        num_tables: 8,
43        num_probes: 2,
44        estimated_recall: 0.0,
45    };
46    let mut best_cost = f64::MAX;
47
48    for k in 4..=32usize {
49        let p_table = p_collision.powi(k as i32);
50
51        // Minimum L tables so that 1 - (1 - p_table)^L >= target_recall
52        let l_frac = (1.0 - target_recall).ln() / (1.0 - p_table).ln();
53        let l = (l_frac.ceil() as usize).clamp(1, 100);
54
55        let recall = 1.0 - (1.0 - p_table).powi(l as i32);
56
57        // Cost heuristic balancing memory (proportional to L) and query time (L * K).
58        let cost = l as f64 * (1.0 + k as f64);
59
60        if recall >= target_recall && cost < best_cost {
61            best_cost = cost;
62            let probes = (k / 4).clamp(1, 8);
63            best = SuggestedParams {
64                num_hashes: k,
65                num_tables: l,
66                num_probes: probes,
67                estimated_recall: recall,
68            };
69        }
70    }
71
72    // Larger datasets benefit from more tables.
73    if dataset_size > 100_000 {
74        let scale = ((dataset_size as f64 / 100_000.0).ln() + 1.0).ceil() as usize;
75        best.num_tables = (best.num_tables * scale).min(50);
76    }
77
78    if target_recall > 0.95 {
79        best.num_probes = (best.num_probes * 2).min(best.num_hashes);
80    }
81
82    best
83}
84
85/// Estimate recall for a given set of LSH parameters.
86///
87/// Accounts for multi-probe by approximating the additional collision probability
88/// from flipping the most uncertain bits.
89pub fn estimate_recall(
90    num_hashes: usize,
91    num_tables: usize,
92    num_probes: usize,
93    metric: DistanceMetric,
94) -> f64 {
95    let p_collision: f64 = match metric {
96        DistanceMetric::Cosine | DistanceMetric::DotProduct => 0.667,
97        DistanceMetric::Euclidean => 0.6,
98    };
99
100    let p_table = p_collision.powi(num_hashes as i32);
101
102    // Each probe adds roughly p^(K-1) * (1-p) probability.
103    let p_probe_bonus = if num_hashes > 1 {
104        num_probes as f64
105            * p_collision.powi((num_hashes - 1) as i32)
106            * (1.0 - p_collision)
107    } else {
108        0.0
109    };
110    let p_effective = (p_table + p_probe_bonus).min(1.0);
111
112    1.0 - (1.0 - p_effective).powi(num_tables as i32)
113}
114
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_suggest_params_reasonable() {
        let suggested = suggest_params(0.9, 100_000, 768, DistanceMetric::Cosine);

        assert!((4..=32).contains(&suggested.num_hashes));
        assert!(suggested.num_tables > 0);
        assert!(suggested.estimated_recall >= 0.9);
    }

    #[test]
    fn test_higher_recall_needs_more_resources() {
        // The stricter target must not end up with both fewer tables and fewer probes.
        let [low, high] =
            [0.8, 0.95].map(|recall| suggest_params(recall, 10_000, 128, DistanceMetric::Cosine));

        let more_tables = high.num_tables >= low.num_tables;
        let more_probes = high.num_probes >= low.num_probes;
        assert!(
            more_tables || more_probes,
            "high recall params should use more resources: low={low:?} high={high:?}"
        );
    }

    #[test]
    fn test_estimate_recall_increases_with_tables() {
        let recall_with = |tables| estimate_recall(16, tables, 2, DistanceMetric::Cosine);
        let (four, eight) = (recall_with(4), recall_with(8));
        assert!(eight > four);
    }
}