1use rand::Rng;
2use serde::{Deserialize, Serialize};
3use std::cmp::Ordering;
4use std::collections::{BinaryHeap, HashMap, HashSet};
5
/// Wrapper that gives `f32` a total order so it can be used as a key in
/// ordered collections such as `BinaryHeap`.
///
/// Ordinary values compare numerically (`-0.0` and `0.0` are equal). `NaN`
/// compares equal to itself and greater than every other value, which keeps
/// `Eq` reflexive and `Ord` transitive even when a distance function yields
/// `NaN`.
#[derive(Debug, Clone, Copy)]
pub struct OrderedFloat(pub f32);

impl PartialEq for OrderedFloat {
    /// Equality is derived from [`Ord::cmp`] so that `==` and `cmp` always agree.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for OrderedFloat {}

impl PartialOrd for OrderedFloat {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for OrderedFloat {
    fn cmp(&self, other: &Self) -> Ordering {
        // `partial_cmp` is `None` only when at least one side is NaN. In that
        // case order by "is NaN" (`false < true`): NaN == NaN, and NaN is
        // greater than any number.
        self.0
            .partial_cmp(&other.0)
            .unwrap_or_else(|| self.0.is_nan().cmp(&other.0.is_nan()))
    }
}
22
/// Tuning parameters for an [`HnswIndex`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HnswConfig {
    /// Number of nearest candidates a newly inserted node is linked to on
    /// each layer it participates in.
    pub m: usize,
    /// Connection cap for a node on layers above 0; a neighbor list that
    /// grows past this during insertion is pruned to its closest entries.
    pub m_max: usize,
    /// Connection cap for a node on layer 0 (same pruning rule as `m_max`).
    pub m_max0: usize,
    /// Candidate-list size (`ef`) passed to the layer search during insertion.
    pub ef_construction: usize,
    /// Query-time candidate-list size.
    // NOTE(review): not read by the index code visible in this file;
    // `HnswIndex::search` takes `ef_search` as an explicit argument instead.
    pub ef_search: usize,
    /// Multiplier applied to `-ln(u)` (with `u` a uniform random sample) when
    /// drawing the top layer of a new node.
    pub ml: f32,
}
32
33impl Default for HnswConfig {
34 fn default() -> Self {
35 Self {
36 m: 16,
37 m_max: 16,
38 m_max0: 32,
39 ef_construction: 200,
40 ef_search: 64,
41 ml: 1.0 / 16.0f32.ln(),
42 }
43 }
44}
45
/// Layered proximity graph (HNSW) over externally stored vectors.
///
/// The index stores only node ids and their links; vectors are supplied by
/// the caller through a `get_vector` callback on every operation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HnswIndex {
    /// Parameters controlling construction and search.
    pub config: HnswConfig,
    /// Node id at which insertion and search start; `None` while the index is empty.
    pub entry_point: Option<usize>,
    /// Top layer of the current entry point (0 for an empty index).
    pub max_layer: usize,
    /// Adjacency lists: node id -> one neighbor list per layer, where index 0
    /// is layer 0 and the vector length is the node's top layer + 1.
    pub graph: HashMap<usize, Vec<Vec<usize>>>,
}
54
55impl HnswIndex {
56 pub fn new(config: HnswConfig) -> Self {
57 Self {
58 config,
59 entry_point: None,
60 max_layer: 0,
61 graph: HashMap::new(),
62 }
63 }
64
65 fn random_layer(&self) -> usize {
66 let mut rng = rand::thread_rng();
67 let unif: f32 = rng.gen_range(0.0..1.0);
68 (-unif.ln() * self.config.ml).floor() as usize
69 }
70
71 pub fn insert<'a>(
72 &mut self,
73 node_id: usize,
74 vector: &[f32],
75 get_vector: &impl Fn(usize) -> &'a [f32],
76 distance_fn: &impl Fn(&[f32], &[f32]) -> f32,
77 ) {
78 let l = self.random_layer();
79
80 self.graph.insert(node_id, vec![vec![]; l + 1]);
82
83 let mut ep = if let Some(ep) = self.entry_point {
84 ep
85 } else {
86 self.entry_point = Some(node_id);
87 self.max_layer = l;
88 return;
89 };
90
91 let max_layer = self.max_layer;
92
93 let mut curr_node = ep;
94
95 for lc in (l + 1..=max_layer).rev() {
97 let mut curr_dist = distance_fn(vector, get_vector(curr_node));
98 let mut changed = true;
99 while changed {
100 changed = false;
101 if let Some(neighbors) = self.graph.get(&curr_node).and_then(|g| g.get(lc)) {
102 for &neighbor in neighbors {
103 let dist = distance_fn(vector, get_vector(neighbor));
104 if dist < curr_dist {
105 curr_dist = dist;
106 curr_node = neighbor;
107 changed = true;
108 }
109 }
110 }
111 }
112 }
113
114 ep = curr_node;
116 for lc in (0..=l.min(max_layer)).rev() {
117 let mut w = self.search_layer(
118 vector,
119 ep,
120 self.config.ef_construction,
121 lc,
122 get_vector,
123 distance_fn,
124 );
125 let neighbors = self.select_neighbors(&mut w, self.config.m);
126
127 for &neighbor in &neighbors {
128 self.graph.get_mut(&node_id).unwrap()[lc].push(neighbor);
129 self.graph.get_mut(&neighbor).unwrap()[lc].push(node_id);
130
131 let m_max = if lc == 0 {
132 self.config.m_max0
133 } else {
134 self.config.m_max
135 };
136 let neighbor_conns = &mut self.graph.get_mut(&neighbor).unwrap()[lc];
137 if neighbor_conns.len() > m_max {
138 let mut e_conn = BinaryHeap::new();
140 for &n2 in neighbor_conns.iter() {
141 let d = distance_fn(get_vector(neighbor), get_vector(n2));
142 e_conn.push((OrderedFloat(-d), n2));
146 }
147 let mut new_conns = Vec::new();
148 while let Some((_, n)) = e_conn.pop() {
149 new_conns.push(n);
150 if new_conns.len() == m_max {
151 break;
152 }
153 }
154 *neighbor_conns = new_conns;
155 }
156 }
157 ep = w
158 .iter()
159 .min_by_key(|(d, _)| OrderedFloat(*d))
160 .map(|(_, id)| *id)
161 .unwrap_or(ep);
162 }
163
164 if l > max_layer {
165 self.max_layer = l;
166 self.entry_point = Some(node_id);
167 }
168 }
169
170 pub fn search<'a>(
171 &self,
172 query: &[f32],
173 k: usize,
174 ef_search: usize,
175 get_vector: &impl Fn(usize) -> &'a [f32],
176 distance_fn: &impl Fn(&[f32], &[f32]) -> f32,
177 ) -> Vec<(usize, f32)> {
178 let mut ep = if let Some(ep) = self.entry_point {
179 ep
180 } else {
181 return Vec::new();
182 };
183
184 let max_layer = self.max_layer;
185
186 for lc in (1..=max_layer).rev() {
187 let mut curr_dist = distance_fn(query, get_vector(ep));
188 let mut changed = true;
189 while changed {
190 changed = false;
191 if let Some(neighbors) = self.graph.get(&ep).and_then(|g| g.get(lc)) {
192 for &neighbor in neighbors {
193 let dist = distance_fn(query, get_vector(neighbor));
194 if dist < curr_dist {
195 curr_dist = dist;
196 ep = neighbor;
197 changed = true;
198 }
199 }
200 }
201 }
202 }
203
204 let w = self.search_layer(query, ep, ef_search.max(k), 0, get_vector, distance_fn);
205
206 let mut res = w.into_iter().collect::<Vec<_>>();
207 res.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
208 res.truncate(k);
209 res.into_iter().map(|(d, id)| (id, d)).collect()
210 }
211
212 fn search_layer<'a>(
213 &self,
214 query: &[f32],
215 ep: usize,
216 ef: usize,
217 lc: usize,
218 get_vector: &impl Fn(usize) -> &'a [f32],
219 distance_fn: &impl Fn(&[f32], &[f32]) -> f32,
220 ) -> Vec<(f32, usize)> {
221 let mut v = HashSet::new();
222 let mut c = BinaryHeap::new(); let mut w = BinaryHeap::new(); let d = distance_fn(query, get_vector(ep));
226 v.insert(ep);
227 c.push((OrderedFloat(-d), ep));
228 w.push((OrderedFloat(d), ep));
229
230 while let Some((OrderedFloat(neg_c_dist), c_id)) = c.pop() {
231 let c_dist = -neg_c_dist;
232 let f_dist = w.peek().unwrap().0 .0;
233 if c_dist > f_dist {
234 break;
235 }
236
237 if let Some(neighbors) = self.graph.get(&c_id).and_then(|g| g.get(lc)) {
238 for &e in neighbors {
239 if !v.contains(&e) {
240 v.insert(e);
241 let f_dist = w.peek().unwrap().0 .0;
242 let e_dist = distance_fn(query, get_vector(e));
243
244 if e_dist < f_dist || w.len() < ef {
245 c.push((OrderedFloat(-e_dist), e));
246 w.push((OrderedFloat(e_dist), e));
247 if w.len() > ef {
248 w.pop();
249 }
250 }
251 }
252 }
253 }
254 }
255
256 w.into_iter().map(|(OrderedFloat(d), id)| (d, id)).collect()
257 }
258
259 fn select_neighbors(&self, candidates: &mut [(f32, usize)], m: usize) -> Vec<usize> {
260 candidates.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
261 candidates.iter().take(m).map(|(_, id)| *id).collect()
262 }
263}
264
#[cfg(test)]
mod tests {
    use super::*;

    /// Euclidean (L2) distance between two equally sized vectors.
    fn euclidean(a: &[f32], b: &[f32]) -> f32 {
        let squared: f32 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum();
        squared.sqrt()
    }

    /// Builds a four-point index and checks that the nearest neighbor of a
    /// query next to point 0 is point 0.
    #[test]
    fn test_hnsw() {
        let vectors: Vec<Vec<f32>> = vec![
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
            vec![-1.0, 0.0],
        ];
        let get_vector = |id: usize| vectors[id].as_slice();
        let distance_fn = |a: &[f32], b: &[f32]| euclidean(a, b);

        let mut index = HnswIndex::new(HnswConfig::default());
        vectors
            .iter()
            .enumerate()
            .for_each(|(id, point)| index.insert(id, point, &get_vector, &distance_fn));

        let query = [1.0, 0.1];
        let res = index.search(&query, 2, 10, &get_vector, &distance_fn);

        assert_eq!(res.len(), 2);
        assert_eq!(res[0].0, 0);
    }
}