Skip to main content

sql_rs/vector/
hnsw.rs

1use super::{Distance, DistanceMetric};
2use serde::{Deserialize, Serialize};
3use std::cmp::Ordering;
4use std::collections::{BinaryHeap, HashSet};
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
7struct Node {
8    vector: Vec<f32>,
9    connections: Vec<Vec<usize>>,
10}
11
/// A node id paired with its distance to the current query, ordered for a *min*-heap.
///
/// `Ord` is reversed on `distance`, so `BinaryHeap<HeapItem>` pops the closest item first;
/// wrap items in `std::cmp::Reverse` to obtain a farthest-first heap. Only `distance` takes
/// part in comparisons (`id` is ignored). Distances are compared with `f32::total_cmp`, which
/// is a genuine total order even for NaN (a positive NaN ranks as farther than any number),
/// so `PartialEq`, `PartialOrd` and `Ord` always agree.
#[derive(Debug, Clone)]
struct HeapItem {
    distance: f32,
    id: usize,
}

impl PartialEq for HeapItem {
    fn eq(&self, other: &Self) -> bool {
        // Derived from `Ord` so equality is reflexive (NaN == NaN) and consistent with `cmp`.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for HeapItem {}

impl PartialOrd for HeapItem {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for HeapItem {
    fn cmp(&self, other: &Self) -> Ordering {
        // Reversed: the smaller distance is the "greater" item, so a max-heap yields it first.
        other.distance.total_cmp(&self.distance)
    }
}
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct HnswIndex {
40    nodes: Vec<Node>,
41    max_layers: usize,
42    m: usize,
43    ef_construction: usize,
44    metric: DistanceMetric,
45}
46
47impl HnswIndex {
48    pub fn new(metric: DistanceMetric) -> Self {
49        Self {
50            nodes: Vec::new(),
51            max_layers: 4,
52            m: 16,
53            ef_construction: 200,
54            metric,
55        }
56    }
57
58    pub fn add(&mut self, vector: Vec<f32>) -> usize {
59        let id = self.nodes.len();
60        let layers = if self.nodes.is_empty() {
61            0
62        } else {
63            self.random_layer()
64        };
65
66        let mut connections = vec![Vec::new(); layers + 1];
67
68        if !self.nodes.is_empty() {
69            let entry_point = 0;
70            let top_layer = self.nodes[entry_point].connections.len().saturating_sub(1);
71
72            let mut nearest = vec![HeapItem {
73                distance: vector.distance(&self.nodes[entry_point].vector, self.metric),
74                id: entry_point,
75            }];
76
77            for layer in ((layers + 1)..=top_layer).rev() {
78                nearest = self.search_layer(&vector, nearest[0].id, 1, layer);
79            }
80
81            for layer in (0..=layers.min(top_layer)).rev() {
82                let candidates =
83                    self.search_layer(&vector, nearest[0].id, self.ef_construction, layer);
84
85                let m = if layer == 0 { self.m * 2 } else { self.m };
86                connections[layer] = candidates.iter().take(m).map(|item| item.id).collect();
87
88                for &neighbor_id in &connections[layer] {
89                    if neighbor_id < self.nodes.len()
90                        && layer < self.nodes[neighbor_id].connections.len()
91                    {
92                        self.nodes[neighbor_id].connections[layer].push(id);
93
94                        if self.nodes[neighbor_id].connections[layer].len() > m {
95                            self.prune_connections(neighbor_id, layer, m);
96                        }
97                    }
98                }
99
100                if !candidates.is_empty() {
101                    nearest = candidates;
102                }
103            }
104        }
105
106        self.nodes.push(Node {
107            vector,
108            connections,
109        });
110
111        id
112    }
113
114    pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
115        if self.nodes.is_empty() {
116            return Vec::new();
117        }
118
119        let entry_point = 0;
120        let top_layer = self.nodes[entry_point].connections.len() - 1;
121
122        let mut nearest = vec![HeapItem {
123            distance: query.distance(&self.nodes[entry_point].vector, self.metric),
124            id: entry_point,
125        }];
126
127        for layer in (1..=top_layer).rev() {
128            nearest = self.search_layer(query, nearest[0].id, 1, layer);
129        }
130
131        let results = self.search_layer(query, nearest[0].id, k.max(self.ef_construction), 0);
132
133        results
134            .into_iter()
135            .take(k)
136            .map(|item| (item.id, item.distance))
137            .collect()
138    }
139
140    fn search_layer(
141        &self,
142        query: &[f32],
143        entry_point: usize,
144        ef: usize,
145        layer: usize,
146    ) -> Vec<HeapItem> {
147        let mut visited = HashSet::new();
148        let mut candidates = BinaryHeap::new();
149        let mut results = BinaryHeap::new();
150
151        let distance = query.distance(&self.nodes[entry_point].vector, self.metric);
152        let item = HeapItem {
153            distance,
154            id: entry_point,
155        };
156
157        candidates.push(item.clone());
158        results.push(item);
159        visited.insert(entry_point);
160
161        while let Some(current) = candidates.pop() {
162            if results.len() >= ef {
163                if let Some(worst) = results.peek() {
164                    if current.distance > worst.distance {
165                        break;
166                    }
167                }
168            }
169
170            if current.id < self.nodes.len() && layer < self.nodes[current.id].connections.len() {
171                for &neighbor_id in &self.nodes[current.id].connections[layer] {
172                    if neighbor_id < self.nodes.len() && visited.insert(neighbor_id) {
173                        let distance = query.distance(&self.nodes[neighbor_id].vector, self.metric);
174                        let item = HeapItem {
175                            distance,
176                            id: neighbor_id,
177                        };
178
179                        if results.len() < ef {
180                            candidates.push(item.clone());
181                            results.push(item);
182                        } else if let Some(worst) = results.peek() {
183                            if distance < worst.distance {
184                                candidates.push(item.clone());
185                                results.pop();
186                                results.push(item);
187                            }
188                        }
189                    }
190                }
191            }
192        }
193
194        let mut sorted: Vec<_> = results.into_iter().collect();
195        sorted.sort_by(|a, b| {
196            a.distance
197                .partial_cmp(&b.distance)
198                .unwrap_or(Ordering::Equal)
199        });
200        sorted
201    }
202
203    fn prune_connections(&mut self, node_id: usize, layer: usize, m: usize) {
204        if node_id >= self.nodes.len() || layer >= self.nodes[node_id].connections.len() {
205            return;
206        }
207
208        let node_vector = self.nodes[node_id].vector.clone();
209        let connections = &self.nodes[node_id].connections[layer];
210
211        let mut distances: Vec<_> = connections
212            .iter()
213            .filter(|&&id| id < self.nodes.len())
214            .map(|&id| {
215                let dist = node_vector.distance(&self.nodes[id].vector, self.metric);
216                (id, dist)
217            })
218            .collect();
219
220        distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
221
222        self.nodes[node_id].connections[layer] =
223            distances.into_iter().take(m).map(|(id, _)| id).collect();
224    }
225
226    fn random_layer(&self) -> usize {
227        let mut layer = 0;
228        while layer < self.max_layers && rand::random::<f32>() < 0.5 {
229            layer += 1;
230        }
231        layer
232    }
233
234    pub fn len(&self) -> usize {
235        self.nodes.len()
236    }
237
238    pub fn is_empty(&self) -> bool {
239        self.nodes.is_empty()
240    }
241}
242
/// Minimal, dependency-free pseudo-random source used for layer assignment.
///
/// Each thread owns an xorshift64 generator seeded with the same constant, so every thread
/// produces the same reproducible sequence.
mod rand {
    use std::cell::Cell;

    thread_local! {
        static STATE: Cell<u64> = Cell::new(0x1234_5678_9abc_def0);
    }

    /// Advances this thread's xorshift64 state and returns it scaled into `[0, 1]` as `T`.
    pub fn random<T: From<f32>>() -> T {
        let bits = STATE.with(|state| {
            let mut s = state.get();
            s ^= s << 13;
            s ^= s >> 7;
            s ^= s << 17;
            state.set(s);
            s
        });
        T::from(bits as f32 / u64::MAX as f32)
    }
}
260}