1use crate::distance::DistanceMetric;
2
/// A set of LSH index parameters together with the recall the tuning model
/// predicts for them.
#[derive(Clone, Debug)]
pub struct SuggestedParams {
    /// Number of hash functions concatenated per table (`k` in the search).
    pub num_hashes: usize,
    /// Number of hash tables (`L` in the search).
    pub num_tables: usize,
    /// Number of probes suggested per table lookup.
    pub num_probes: usize,
    /// Recall predicted by the collision-probability model, in `[0, 1]`.
    pub estimated_recall: f64,
}
11
12pub fn suggest_params(
26 target_recall: f64,
27 dataset_size: usize,
28 _dim: usize,
29 metric: DistanceMetric,
30) -> SuggestedParams {
31 let target_recall = target_recall.clamp(0.5, 0.999);
32
33 let p_collision: f64 = match metric {
36 DistanceMetric::Cosine | DistanceMetric::DotProduct => 0.667,
37 DistanceMetric::Euclidean => 0.6,
38 };
39
40 let mut best = SuggestedParams {
41 num_hashes: 16,
42 num_tables: 8,
43 num_probes: 2,
44 estimated_recall: 0.0,
45 };
46 let mut best_cost = f64::MAX;
47
48 for k in 4..=32usize {
49 let p_table = p_collision.powi(k as i32);
50
51 let l_frac = (1.0 - target_recall).ln() / (1.0 - p_table).ln();
53 let l = (l_frac.ceil() as usize).clamp(1, 100);
54
55 let recall = 1.0 - (1.0 - p_table).powi(l as i32);
56
57 let cost = l as f64 * (1.0 + k as f64);
59
60 if recall >= target_recall && cost < best_cost {
61 best_cost = cost;
62 let probes = (k / 4).clamp(1, 8);
63 best = SuggestedParams {
64 num_hashes: k,
65 num_tables: l,
66 num_probes: probes,
67 estimated_recall: recall,
68 };
69 }
70 }
71
72 if dataset_size > 100_000 {
74 let scale = ((dataset_size as f64 / 100_000.0).ln() + 1.0).ceil() as usize;
75 best.num_tables = (best.num_tables * scale).min(50);
76 }
77
78 if target_recall > 0.95 {
79 best.num_probes = (best.num_probes * 2).min(best.num_hashes);
80 }
81
82 best
83}
84
85pub fn estimate_recall(
90 num_hashes: usize,
91 num_tables: usize,
92 num_probes: usize,
93 metric: DistanceMetric,
94) -> f64 {
95 let p_collision: f64 = match metric {
96 DistanceMetric::Cosine | DistanceMetric::DotProduct => 0.667,
97 DistanceMetric::Euclidean => 0.6,
98 };
99
100 let p_table = p_collision.powi(num_hashes as i32);
101
102 let p_probe_bonus = if num_hashes > 1 {
104 num_probes as f64
105 * p_collision.powi((num_hashes - 1) as i32)
106 * (1.0 - p_collision)
107 } else {
108 0.0
109 };
110 let p_effective = (p_table + p_probe_bonus).min(1.0);
111
112 1.0 - (1.0 - p_effective).powi(num_tables as i32)
113}
114
#[cfg(test)]
mod tests {
    use super::*;

    /// A mid-range target must yield parameters inside the searched range
    /// whose modelled recall meets the target.
    #[test]
    fn test_suggest_params_reasonable() {
        let SuggestedParams {
            num_hashes,
            num_tables,
            estimated_recall,
            ..
        } = suggest_params(0.9, 100_000, 768, DistanceMetric::Cosine);

        assert!(num_hashes >= 4);
        assert!(num_hashes <= 32);
        assert!(num_tables >= 1);
        assert!(estimated_recall >= 0.9);
    }

    /// Raising the recall target must not lower both the table count and the
    /// probe count.
    #[test]
    fn test_higher_recall_needs_more_resources() {
        let suggest = |target: f64| suggest_params(target, 10_000, 128, DistanceMetric::Cosine);
        let low = suggest(0.8);
        let high = suggest(0.95);

        let at_least_as_many_tables = high.num_tables >= low.num_tables;
        let at_least_as_many_probes = high.num_probes >= low.num_probes;
        assert!(
            at_least_as_many_tables || at_least_as_many_probes,
            "high recall params should use more resources: low={low:?} high={high:?}"
        );
    }

    /// With hashes and probes fixed, more tables must strictly raise the
    /// estimated recall.
    #[test]
    fn test_estimate_recall_increases_with_tables() {
        let recall_with = |tables: usize| estimate_recall(16, tables, 2, DistanceMetric::Cosine);
        let r4 = recall_with(4);
        let r8 = recall_with(8);
        assert!(r8 > r4);
    }
}