Skip to main content

sql_rs/vector/
hnsw.rs

1use super::{Distance, DistanceMetric};
2use std::collections::{BinaryHeap, HashSet};
3use std::cmp::Ordering;
4use serde::{Deserialize, Serialize};
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
7struct Node {
8    vector: Vec<f32>,
9    connections: Vec<Vec<usize>>,
10}
11
/// A node id paired with its distance to the current query, ordered for use
/// in a `BinaryHeap`.
///
/// The ordering is deliberately *reversed* by distance: a smaller distance
/// compares as `Greater`, so a plain `BinaryHeap<HeapItem>` (a max-heap) pops
/// the nearest item first. Only `distance` takes part in comparisons; `id` is
/// payload, so two different nodes at the same distance compare as equal.
///
/// A NaN distance is treated as "farther than everything" (and equal to any
/// other NaN). This keeps the ordering total, which `Ord` and `BinaryHeap`
/// require; mapping every NaN comparison to `Equal` would make NaN equal to
/// two values that are not equal to each other.
#[derive(Debug, Clone)]
struct HeapItem {
    distance: f32,
    id: usize,
}

impl PartialEq for HeapItem {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so that `==` always agrees with `Ord`.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for HeapItem {}

impl PartialOrd for HeapItem {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // Canonical form: the partial order is exactly the total order.
        Some(self.cmp(other))
    }
}

impl Ord for HeapItem {
    fn cmp(&self, other: &Self) -> Ordering {
        // Operands are swapped on purpose (reverse order by distance).
        // `partial_cmp` is `None` only when at least one side is NaN; in that
        // case fall back to comparing the NaN flags, again with swapped
        // operands, so that a NaN item is the *smallest* heap element.
        other
            .distance
            .partial_cmp(&self.distance)
            .unwrap_or_else(|| other.distance.is_nan().cmp(&self.distance.is_nan()))
    }
}
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct HnswIndex {
40    nodes: Vec<Node>,
41    max_layers: usize,
42    m: usize,
43    ef_construction: usize,
44    metric: DistanceMetric,
45}
46
47impl HnswIndex {
48    pub fn new(metric: DistanceMetric) -> Self {
49        Self {
50            nodes: Vec::new(),
51            max_layers: 4,
52            m: 16,
53            ef_construction: 200,
54            metric,
55        }
56    }
57    
58    pub fn add(&mut self, vector: Vec<f32>) -> usize {
59        let id = self.nodes.len();
60        let layers = if self.nodes.is_empty() { 0 } else { self.random_layer() };
61        
62        let mut connections = vec![Vec::new(); layers + 1];
63        
64        if !self.nodes.is_empty() {
65            let entry_point = 0;
66            let top_layer = self.nodes[entry_point].connections.len().saturating_sub(1);
67            
68            let mut nearest = vec![HeapItem {
69                distance: vector.distance(&self.nodes[entry_point].vector, self.metric),
70                id: entry_point,
71            }];
72            
73            for layer in ((layers + 1)..=top_layer).rev() {
74                nearest = self.search_layer(&vector, nearest[0].id, 1, layer);
75            }
76            
77            for layer in (0..=layers.min(top_layer)).rev() {
78                let candidates = self.search_layer(&vector, nearest[0].id, self.ef_construction, layer);
79                
80                let m = if layer == 0 { self.m * 2 } else { self.m };
81                connections[layer] = candidates.iter().take(m).map(|item| item.id).collect();
82                
83                for &neighbor_id in &connections[layer] {
84                    if neighbor_id < self.nodes.len() && layer < self.nodes[neighbor_id].connections.len() {
85                        self.nodes[neighbor_id].connections[layer].push(id);
86                        
87                        if self.nodes[neighbor_id].connections[layer].len() > m {
88                            self.prune_connections(neighbor_id, layer, m);
89                        }
90                    }
91                }
92                
93                if !candidates.is_empty() {
94                    nearest = candidates;
95                }
96            }
97        }
98        
99        self.nodes.push(Node {
100            vector,
101            connections,
102        });
103
104        id
105    }
106    
107    pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
108        if self.nodes.is_empty() {
109            return Vec::new();
110        }
111        
112        let entry_point = 0;
113        let top_layer = self.nodes[entry_point].connections.len() - 1;
114        
115        let mut nearest = vec![HeapItem {
116            distance: query.distance(&self.nodes[entry_point].vector, self.metric),
117            id: entry_point,
118        }];
119        
120        for layer in (1..=top_layer).rev() {
121            nearest = self.search_layer(query, nearest[0].id, 1, layer);
122        }
123        
124        let results = self.search_layer(query, nearest[0].id, k.max(self.ef_construction), 0);
125        
126        results.into_iter()
127            .take(k)
128            .map(|item| (item.id, item.distance))
129            .collect()
130    }
131    
132    fn search_layer(&self, query: &[f32], entry_point: usize, ef: usize, layer: usize) -> Vec<HeapItem> {
133        let mut visited = HashSet::new();
134        let mut candidates = BinaryHeap::new();
135        let mut results = BinaryHeap::new();
136        
137        let distance = query.distance(&self.nodes[entry_point].vector, self.metric);
138        let item = HeapItem {
139            distance,
140            id: entry_point,
141        };
142        
143        candidates.push(item.clone());
144        results.push(item);
145        visited.insert(entry_point);
146        
147        while let Some(current) = candidates.pop() {
148            if results.len() >= ef {
149                if let Some(worst) = results.peek() {
150                    if current.distance > worst.distance {
151                        break;
152                    }
153                }
154            }
155            
156            if current.id < self.nodes.len() && layer < self.nodes[current.id].connections.len() {
157                for &neighbor_id in &self.nodes[current.id].connections[layer] {
158                    if neighbor_id < self.nodes.len() && visited.insert(neighbor_id) {
159                        let distance = query.distance(&self.nodes[neighbor_id].vector, self.metric);
160                        let item = HeapItem {
161                            distance,
162                            id: neighbor_id,
163                        };
164                        
165                        if results.len() < ef {
166                            candidates.push(item.clone());
167                            results.push(item);
168                        } else if let Some(worst) = results.peek() {
169                            if distance < worst.distance {
170                                candidates.push(item.clone());
171                                results.pop();
172                                results.push(item);
173                            }
174                        }
175                    }
176                }
177            }
178        }
179        
180        let mut sorted: Vec<_> = results.into_iter().collect();
181        sorted.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap_or(Ordering::Equal));
182        sorted
183    }
184    
185    fn prune_connections(&mut self, node_id: usize, layer: usize, m: usize) {
186        if node_id >= self.nodes.len() || layer >= self.nodes[node_id].connections.len() {
187            return;
188        }
189        
190        let node_vector = self.nodes[node_id].vector.clone();
191        let connections = &self.nodes[node_id].connections[layer];
192        
193        let mut distances: Vec<_> = connections
194            .iter()
195            .filter(|&&id| id < self.nodes.len())
196            .map(|&id| {
197                let dist = node_vector.distance(&self.nodes[id].vector, self.metric);
198                (id, dist)
199            })
200            .collect();
201        
202        distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
203        
204        self.nodes[node_id].connections[layer] = distances
205            .into_iter()
206            .take(m)
207            .map(|(id, _)| id)
208            .collect();
209    }
210    
211    fn random_layer(&self) -> usize {
212        let mut layer = 0;
213        while layer < self.max_layers && rand::random::<f32>() < 0.5 {
214            layer += 1;
215        }
216        layer
217    }
218    
219    pub fn len(&self) -> usize {
220        self.nodes.len()
221    }
222    
223    pub fn is_empty(&self) -> bool {
224        self.nodes.is_empty()
225    }
226}
227
/// Minimal in-file pseudo-random source.
///
/// Each thread owns its own generator state, and every thread starts from
/// the same fixed seed, so the sequence produced on a thread is reproducible.
mod rand {
    use std::cell::Cell;

    thread_local! {
        static RNG: Cell<u64> = Cell::new(0x123456789abcdef0);
    }

    /// Advances a 64-bit state by one shift/xor round (shifts 13, 7, 17).
    fn step(state: u64) -> u64 {
        let a = state ^ (state << 13);
        let b = a ^ (a >> 7);
        b ^ (b << 17)
    }

    /// Returns the next pseudo-random value, built from an `f32` in
    /// `0.0..=1.0` (the new state divided by `u64::MAX`).
    pub fn random<T>() -> T
    where
        T: From<f32>,
    {
        let bits = RNG.with(|cell| {
            let next = step(cell.get());
            cell.set(next);
            next
        });
        T::from((bits as f32) / (u64::MAX as f32))
    }
}
245}