Skip to main content

veclite_index/
hnsw.rs

1use rand::Rng;
2use serde::{Deserialize, Serialize};
3use std::cmp::Ordering;
4use std::collections::{BinaryHeap, HashMap, HashSet};
5
/// An `f32` wrapper with a *total* order, so distances can be stored in
/// `BinaryHeap`s and sorted.
///
/// Ordering follows [`f32::total_cmp`] (IEEE 754 `totalOrder`):
/// - ordinary numbers compare as usual;
/// - a positive NaN sorts above `+inf` (and a negative NaN below `-inf`);
/// - `-0.0` sorts just below `+0.0`.
///
/// Equality is defined through the same ordering, so `Eq` and `Ord` agree and
/// `Eq` is reflexive even for NaN.
#[derive(Debug, Clone, Copy)]
pub struct OrderedFloat(pub f32);

impl PartialEq for OrderedFloat {
    fn eq(&self, other: &Self) -> bool {
        // Derive equality from `cmp` so it stays consistent with `Ord`.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for OrderedFloat {}

impl PartialOrd for OrderedFloat {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for OrderedFloat {
    fn cmp(&self, other: &Self) -> Ordering {
        // `partial_cmp(..).unwrap_or(Equal)` made NaN "equal" to every value,
        // which is not transitive and corrupts heap/sort invariants.
        self.0.total_cmp(&other.0)
    }
}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct HnswConfig {
25    pub m: usize,
26    pub m_max: usize,
27    pub m_max0: usize,
28    pub ef_construction: usize,
29    pub ef_search: usize,
30    pub ml: f32,
31}
32
33impl Default for HnswConfig {
34    fn default() -> Self {
35        Self {
36            m: 16,
37            m_max: 16,
38            m_max0: 32,
39            ef_construction: 200,
40            ef_search: 64,
41            ml: 1.0 / 16.0f32.ln(),
42        }
43    }
44}
45
46#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct HnswIndex {
48    pub config: HnswConfig,
49    pub entry_point: Option<usize>,
50    pub max_layer: usize,
51    // Node ID -> [Layer 0 connections, Layer 1 connections, ...]
52    pub graph: HashMap<usize, Vec<Vec<usize>>>,
53}
54
55impl HnswIndex {
56    pub fn new(config: HnswConfig) -> Self {
57        Self {
58            config,
59            entry_point: None,
60            max_layer: 0,
61            graph: HashMap::new(),
62        }
63    }
64
65    fn random_layer(&self) -> usize {
66        let mut rng = rand::thread_rng();
67        let unif: f32 = rng.gen_range(0.0..1.0);
68        (-unif.ln() * self.config.ml).floor() as usize
69    }
70
71    pub fn insert<'a>(
72        &mut self,
73        node_id: usize,
74        vector: &[f32],
75        get_vector: &impl Fn(usize) -> &'a [f32],
76        distance_fn: &impl Fn(&[f32], &[f32]) -> f32,
77    ) {
78        let l = self.random_layer();
79
80        // Initialize layers for the new node
81        self.graph.insert(node_id, vec![vec![]; l + 1]);
82
83        let mut ep = if let Some(ep) = self.entry_point {
84            ep
85        } else {
86            self.entry_point = Some(node_id);
87            self.max_layer = l;
88            return;
89        };
90
91        let max_layer = self.max_layer;
92
93        let mut curr_node = ep;
94
95        // Phase 1: Traverse from top layer to l+1
96        for lc in (l + 1..=max_layer).rev() {
97            let mut curr_dist = distance_fn(vector, get_vector(curr_node));
98            let mut changed = true;
99            while changed {
100                changed = false;
101                if let Some(neighbors) = self.graph.get(&curr_node).and_then(|g| g.get(lc)) {
102                    for &neighbor in neighbors {
103                        let dist = distance_fn(vector, get_vector(neighbor));
104                        if dist < curr_dist {
105                            curr_dist = dist;
106                            curr_node = neighbor;
107                            changed = true;
108                        }
109                    }
110                }
111            }
112        }
113
114        // Phase 2: Insert and connect at layers l down to 0
115        ep = curr_node;
116        for lc in (0..=l.min(max_layer)).rev() {
117            let mut w = self.search_layer(
118                vector,
119                ep,
120                self.config.ef_construction,
121                lc,
122                get_vector,
123                distance_fn,
124            );
125            let neighbors = self.select_neighbors(&mut w, self.config.m);
126
127            for &neighbor in &neighbors {
128                self.graph.get_mut(&node_id).unwrap()[lc].push(neighbor);
129                self.graph.get_mut(&neighbor).unwrap()[lc].push(node_id);
130
131                let m_max = if lc == 0 {
132                    self.config.m_max0
133                } else {
134                    self.config.m_max
135                };
136                let neighbor_conns = &mut self.graph.get_mut(&neighbor).unwrap()[lc];
137                if neighbor_conns.len() > m_max {
138                    // Shrink neighbor connections
139                    let mut e_conn = BinaryHeap::new();
140                    for &n2 in neighbor_conns.iter() {
141                        let d = distance_fn(get_vector(neighbor), get_vector(n2));
142                        // max heap, we want to keep closest, so push with dist.
143                        // To keep smallest distances, we need to sort and take top, or use a min heap.
144                        // wait, standard binary heap is max heap. If we want smallest distance, we can put standard dist to find the largest to drop, or just sort.
145                        e_conn.push((OrderedFloat(-d), n2));
146                    }
147                    let mut new_conns = Vec::new();
148                    while let Some((_, n)) = e_conn.pop() {
149                        new_conns.push(n);
150                        if new_conns.len() == m_max {
151                            break;
152                        }
153                    }
154                    *neighbor_conns = new_conns;
155                }
156            }
157            ep = w
158                .iter()
159                .min_by_key(|(d, _)| OrderedFloat(*d))
160                .map(|(_, id)| *id)
161                .unwrap_or(ep);
162        }
163
164        if l > max_layer {
165            self.max_layer = l;
166            self.entry_point = Some(node_id);
167        }
168    }
169
170    pub fn search<'a>(
171        &self,
172        query: &[f32],
173        k: usize,
174        ef_search: usize,
175        get_vector: &impl Fn(usize) -> &'a [f32],
176        distance_fn: &impl Fn(&[f32], &[f32]) -> f32,
177    ) -> Vec<(usize, f32)> {
178        let mut ep = if let Some(ep) = self.entry_point {
179            ep
180        } else {
181            return Vec::new();
182        };
183
184        let max_layer = self.max_layer;
185
186        for lc in (1..=max_layer).rev() {
187            let mut curr_dist = distance_fn(query, get_vector(ep));
188            let mut changed = true;
189            while changed {
190                changed = false;
191                if let Some(neighbors) = self.graph.get(&ep).and_then(|g| g.get(lc)) {
192                    for &neighbor in neighbors {
193                        let dist = distance_fn(query, get_vector(neighbor));
194                        if dist < curr_dist {
195                            curr_dist = dist;
196                            ep = neighbor;
197                            changed = true;
198                        }
199                    }
200                }
201            }
202        }
203
204        let w = self.search_layer(query, ep, ef_search.max(k), 0, get_vector, distance_fn);
205
206        let mut res = w.into_iter().collect::<Vec<_>>();
207        res.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
208        res.truncate(k);
209        res.into_iter().map(|(d, id)| (id, d)).collect()
210    }
211
212    fn search_layer<'a>(
213        &self,
214        query: &[f32],
215        ep: usize,
216        ef: usize,
217        lc: usize,
218        get_vector: &impl Fn(usize) -> &'a [f32],
219        distance_fn: &impl Fn(&[f32], &[f32]) -> f32,
220    ) -> Vec<(f32, usize)> {
221        let mut v = HashSet::new();
222        let mut c = BinaryHeap::new(); // Candidates to explore: min-heap (ordered by dist asc -> store -dist)
223        let mut w = BinaryHeap::new(); // Nearest neighbors found: max-heap (ordered by dist desc -> store dist)
224
225        let d = distance_fn(query, get_vector(ep));
226        v.insert(ep);
227        c.push((OrderedFloat(-d), ep));
228        w.push((OrderedFloat(d), ep));
229
230        while let Some((OrderedFloat(neg_c_dist), c_id)) = c.pop() {
231            let c_dist = -neg_c_dist;
232            let f_dist = w.peek().unwrap().0 .0;
233            if c_dist > f_dist {
234                break;
235            }
236
237            if let Some(neighbors) = self.graph.get(&c_id).and_then(|g| g.get(lc)) {
238                for &e in neighbors {
239                    if !v.contains(&e) {
240                        v.insert(e);
241                        let f_dist = w.peek().unwrap().0 .0;
242                        let e_dist = distance_fn(query, get_vector(e));
243
244                        if e_dist < f_dist || w.len() < ef {
245                            c.push((OrderedFloat(-e_dist), e));
246                            w.push((OrderedFloat(e_dist), e));
247                            if w.len() > ef {
248                                w.pop();
249                            }
250                        }
251                    }
252                }
253            }
254        }
255
256        w.into_iter().map(|(OrderedFloat(d), id)| (d, id)).collect()
257    }
258
259    fn select_neighbors(&self, candidates: &mut [(f32, usize)], m: usize) -> Vec<usize> {
260        candidates.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
261        candidates.iter().take(m).map(|(_, id)| *id).collect()
262    }
263}
264
#[cfg(test)]
mod tests {
    use super::*;

    /// Straight-line (L2) distance between two points of equal dimension.
    fn euclidean(a: &[f32], b: &[f32]) -> f32 {
        let sum_sq: f32 = a.iter().zip(b).map(|(x, y)| (x - y).powi(2)).sum();
        sum_sq.sqrt()
    }

    #[test]
    fn test_hnsw() {
        let points: Vec<Vec<f32>> = vec![
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
            vec![-1.0, 0.0],
        ];
        let lookup = |id: usize| points[id].as_slice();

        let mut index = HnswIndex::new(HnswConfig::default());
        points
            .iter()
            .enumerate()
            .for_each(|(id, point)| index.insert(id, point, &lookup, &euclidean));

        let hits = index.search(&[1.0, 0.1], 2, 10, &lookup, &euclidean);
        assert_eq!(hits.len(), 2);
        // [1.0, 0.0] is the stored point nearest to the query.
        assert_eq!(hits[0].0, 0);
    }
}
298}