1use super::{Distance, DistanceMetric};
2use serde::{Deserialize, Serialize};
3use std::cmp::Ordering;
4use std::collections::{BinaryHeap, HashSet};
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
7struct Node {
8 vector: Vec<f32>,
9 connections: Vec<Vec<usize>>,
10}
11
/// Search-queue entry: a node id paired with its distance to the query.
///
/// The ordering is *reversed* on `distance`: a smaller distance compares as
/// greater, so a `BinaryHeap<HeapItem>` pops the closest item first.
/// Distances are compared with `f32::total_cmp`. That makes `Ord` a genuine
/// total order and keeps `PartialEq`/`Eq` consistent with it; for example, a
/// NaN distance compares equal to itself, as `Eq` requires. See
/// `f32::total_cmp` for where NaN and `-0.0` fall. `id` takes no part in
/// comparisons.
#[derive(Debug, Clone)]
struct HeapItem {
    distance: f32,
    id: usize,
}

impl PartialEq for HeapItem {
    fn eq(&self, other: &Self) -> bool {
        // Defined via `cmp` so equality can never disagree with the ordering.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for HeapItem {}

impl PartialOrd for HeapItem {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for HeapItem {
    fn cmp(&self, other: &Self) -> Ordering {
        // Arguments are swapped so the max-heap `BinaryHeap` behaves as a
        // min-heap on distance.
        other.distance.total_cmp(&self.distance)
    }
}
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct HnswIndex {
40 nodes: Vec<Node>,
41 max_layers: usize,
42 m: usize,
43 ef_construction: usize,
44 metric: DistanceMetric,
45}
46
47impl HnswIndex {
48 pub fn new(metric: DistanceMetric) -> Self {
49 Self {
50 nodes: Vec::new(),
51 max_layers: 4,
52 m: 16,
53 ef_construction: 200,
54 metric,
55 }
56 }
57
58 pub fn add(&mut self, vector: Vec<f32>) -> usize {
59 let id = self.nodes.len();
60 let layers = if self.nodes.is_empty() {
61 0
62 } else {
63 self.random_layer()
64 };
65
66 let mut connections = vec![Vec::new(); layers + 1];
67
68 if !self.nodes.is_empty() {
69 let entry_point = 0;
70 let top_layer = self.nodes[entry_point].connections.len().saturating_sub(1);
71
72 let mut nearest = vec![HeapItem {
73 distance: vector.distance(&self.nodes[entry_point].vector, self.metric),
74 id: entry_point,
75 }];
76
77 for layer in ((layers + 1)..=top_layer).rev() {
78 nearest = self.search_layer(&vector, nearest[0].id, 1, layer);
79 }
80
81 for layer in (0..=layers.min(top_layer)).rev() {
82 let candidates =
83 self.search_layer(&vector, nearest[0].id, self.ef_construction, layer);
84
85 let m = if layer == 0 { self.m * 2 } else { self.m };
86 connections[layer] = candidates.iter().take(m).map(|item| item.id).collect();
87
88 for &neighbor_id in &connections[layer] {
89 if neighbor_id < self.nodes.len()
90 && layer < self.nodes[neighbor_id].connections.len()
91 {
92 self.nodes[neighbor_id].connections[layer].push(id);
93
94 if self.nodes[neighbor_id].connections[layer].len() > m {
95 self.prune_connections(neighbor_id, layer, m);
96 }
97 }
98 }
99
100 if !candidates.is_empty() {
101 nearest = candidates;
102 }
103 }
104 }
105
106 self.nodes.push(Node {
107 vector,
108 connections,
109 });
110
111 id
112 }
113
114 pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
115 if self.nodes.is_empty() {
116 return Vec::new();
117 }
118
119 let entry_point = 0;
120 let top_layer = self.nodes[entry_point].connections.len() - 1;
121
122 let mut nearest = vec![HeapItem {
123 distance: query.distance(&self.nodes[entry_point].vector, self.metric),
124 id: entry_point,
125 }];
126
127 for layer in (1..=top_layer).rev() {
128 nearest = self.search_layer(query, nearest[0].id, 1, layer);
129 }
130
131 let results = self.search_layer(query, nearest[0].id, k.max(self.ef_construction), 0);
132
133 results
134 .into_iter()
135 .take(k)
136 .map(|item| (item.id, item.distance))
137 .collect()
138 }
139
140 fn search_layer(
141 &self,
142 query: &[f32],
143 entry_point: usize,
144 ef: usize,
145 layer: usize,
146 ) -> Vec<HeapItem> {
147 let mut visited = HashSet::new();
148 let mut candidates = BinaryHeap::new();
149 let mut results = BinaryHeap::new();
150
151 let distance = query.distance(&self.nodes[entry_point].vector, self.metric);
152 let item = HeapItem {
153 distance,
154 id: entry_point,
155 };
156
157 candidates.push(item.clone());
158 results.push(item);
159 visited.insert(entry_point);
160
161 while let Some(current) = candidates.pop() {
162 if results.len() >= ef {
163 if let Some(worst) = results.peek() {
164 if current.distance > worst.distance {
165 break;
166 }
167 }
168 }
169
170 if current.id < self.nodes.len() && layer < self.nodes[current.id].connections.len() {
171 for &neighbor_id in &self.nodes[current.id].connections[layer] {
172 if neighbor_id < self.nodes.len() && visited.insert(neighbor_id) {
173 let distance = query.distance(&self.nodes[neighbor_id].vector, self.metric);
174 let item = HeapItem {
175 distance,
176 id: neighbor_id,
177 };
178
179 if results.len() < ef {
180 candidates.push(item.clone());
181 results.push(item);
182 } else if let Some(worst) = results.peek() {
183 if distance < worst.distance {
184 candidates.push(item.clone());
185 results.pop();
186 results.push(item);
187 }
188 }
189 }
190 }
191 }
192 }
193
194 let mut sorted: Vec<_> = results.into_iter().collect();
195 sorted.sort_by(|a, b| {
196 a.distance
197 .partial_cmp(&b.distance)
198 .unwrap_or(Ordering::Equal)
199 });
200 sorted
201 }
202
203 fn prune_connections(&mut self, node_id: usize, layer: usize, m: usize) {
204 if node_id >= self.nodes.len() || layer >= self.nodes[node_id].connections.len() {
205 return;
206 }
207
208 let node_vector = self.nodes[node_id].vector.clone();
209 let connections = &self.nodes[node_id].connections[layer];
210
211 let mut distances: Vec<_> = connections
212 .iter()
213 .filter(|&&id| id < self.nodes.len())
214 .map(|&id| {
215 let dist = node_vector.distance(&self.nodes[id].vector, self.metric);
216 (id, dist)
217 })
218 .collect();
219
220 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
221
222 self.nodes[node_id].connections[layer] =
223 distances.into_iter().take(m).map(|(id, _)| id).collect();
224 }
225
226 fn random_layer(&self) -> usize {
227 let mut layer = 0;
228 while layer < self.max_layers && rand::random::<f32>() < 0.5 {
229 layer += 1;
230 }
231 layer
232 }
233
234 pub fn len(&self) -> usize {
235 self.nodes.len()
236 }
237
238 pub fn is_empty(&self) -> bool {
239 self.nodes.is_empty()
240 }
241}
242
/// Minimal deterministic pseudo-random source (xorshift-style), so the index
/// needs no external RNG crate. Each thread starts from the same fixed seed,
/// which makes level assignment reproducible.
mod rand {
    use std::cell::Cell;

    thread_local! {
        // Per-thread generator state; the seed is non-zero, which xorshift
        // needs to avoid getting stuck at zero.
        static RNG: Cell<u64> = Cell::new(0x123456789abcdef0);
    }

    /// Maps 64 random bits to a float in `[0, 1)`.
    ///
    /// Only the top 24 bits are used. An `f32` mantissa holds them exactly,
    /// so the result can never round up to `1.0`. Dividing the full `u64` by
    /// `u64::MAX` as `f32` can round up to `1.0`.
    pub(super) fn unit_f32(bits: u64) -> f32 {
        (bits >> 40) as f32 / (1u64 << 24) as f32
    }

    /// Advances this thread's generator and returns the draw as a value in
    /// `[0, 1)`, converted from `f32`.
    pub fn random<T: From<f32>>() -> T {
        RNG.with(|rng| {
            let mut x = rng.get();
            x ^= x << 13;
            x ^= x >> 7;
            x ^= x << 17;
            rng.set(x);
            T::from(unit_f32(x))
        })
    }
}