1use super::{Distance, DistanceMetric};
2use std::collections::{BinaryHeap, HashSet};
3use std::cmp::Ordering;
4use serde::{Deserialize, Serialize};
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
7struct Node {
8 vector: Vec<f32>,
9 connections: Vec<Vec<usize>>,
10}
11
/// A `(distance, id)` pair stored in the search heaps.
///
/// Ordering is *reversed* on `distance`: an item closer to the query compares
/// as greater, so a `BinaryHeap<HeapItem>` pops the nearest item first.
/// All comparisons are derived from [`f32::total_cmp`], which is a genuine
/// total order. NaN distances therefore get a fixed position instead of
/// comparing equal to everything, and the `Eq`/`Ord` contracts that
/// `BinaryHeap` relies on hold.
#[derive(Debug, Clone)]
struct HeapItem {
    distance: f32,
    id: usize,
}

impl PartialEq for HeapItem {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so equality always agrees with the ordering
        // (and is reflexive even for NaN).
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for HeapItem {}

impl PartialOrd for HeapItem {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for HeapItem {
    fn cmp(&self, other: &Self) -> Ordering {
        // Reversed so the max-heap `BinaryHeap` behaves as a min-heap on distance.
        other.distance.total_cmp(&self.distance)
    }
}
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct HnswIndex {
40 nodes: Vec<Node>,
41 max_layers: usize,
42 m: usize,
43 ef_construction: usize,
44 metric: DistanceMetric,
45}
46
47impl HnswIndex {
48 pub fn new(metric: DistanceMetric) -> Self {
49 Self {
50 nodes: Vec::new(),
51 max_layers: 4,
52 m: 16,
53 ef_construction: 200,
54 metric,
55 }
56 }
57
58 pub fn add(&mut self, vector: Vec<f32>) -> usize {
59 let id = self.nodes.len();
60 let layers = if self.nodes.is_empty() { 0 } else { self.random_layer() };
61
62 let mut connections = vec![Vec::new(); layers + 1];
63
64 if !self.nodes.is_empty() {
65 let entry_point = 0;
66 let top_layer = self.nodes[entry_point].connections.len().saturating_sub(1);
67
68 let mut nearest = vec![HeapItem {
69 distance: vector.distance(&self.nodes[entry_point].vector, self.metric),
70 id: entry_point,
71 }];
72
73 for layer in ((layers + 1)..=top_layer).rev() {
74 nearest = self.search_layer(&vector, nearest[0].id, 1, layer);
75 }
76
77 for layer in (0..=layers.min(top_layer)).rev() {
78 let candidates = self.search_layer(&vector, nearest[0].id, self.ef_construction, layer);
79
80 let m = if layer == 0 { self.m * 2 } else { self.m };
81 connections[layer] = candidates.iter().take(m).map(|item| item.id).collect();
82
83 for &neighbor_id in &connections[layer] {
84 if neighbor_id < self.nodes.len() && layer < self.nodes[neighbor_id].connections.len() {
85 self.nodes[neighbor_id].connections[layer].push(id);
86
87 if self.nodes[neighbor_id].connections[layer].len() > m {
88 self.prune_connections(neighbor_id, layer, m);
89 }
90 }
91 }
92
93 if !candidates.is_empty() {
94 nearest = candidates;
95 }
96 }
97 }
98
99 self.nodes.push(Node {
100 vector,
101 connections,
102 });
103
104 id
105 }
106
107 pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
108 if self.nodes.is_empty() {
109 return Vec::new();
110 }
111
112 let entry_point = 0;
113 let top_layer = self.nodes[entry_point].connections.len() - 1;
114
115 let mut nearest = vec![HeapItem {
116 distance: query.distance(&self.nodes[entry_point].vector, self.metric),
117 id: entry_point,
118 }];
119
120 for layer in (1..=top_layer).rev() {
121 nearest = self.search_layer(query, nearest[0].id, 1, layer);
122 }
123
124 let results = self.search_layer(query, nearest[0].id, k.max(self.ef_construction), 0);
125
126 results.into_iter()
127 .take(k)
128 .map(|item| (item.id, item.distance))
129 .collect()
130 }
131
132 fn search_layer(&self, query: &[f32], entry_point: usize, ef: usize, layer: usize) -> Vec<HeapItem> {
133 let mut visited = HashSet::new();
134 let mut candidates = BinaryHeap::new();
135 let mut results = BinaryHeap::new();
136
137 let distance = query.distance(&self.nodes[entry_point].vector, self.metric);
138 let item = HeapItem {
139 distance,
140 id: entry_point,
141 };
142
143 candidates.push(item.clone());
144 results.push(item);
145 visited.insert(entry_point);
146
147 while let Some(current) = candidates.pop() {
148 if results.len() >= ef {
149 if let Some(worst) = results.peek() {
150 if current.distance > worst.distance {
151 break;
152 }
153 }
154 }
155
156 if current.id < self.nodes.len() && layer < self.nodes[current.id].connections.len() {
157 for &neighbor_id in &self.nodes[current.id].connections[layer] {
158 if neighbor_id < self.nodes.len() && visited.insert(neighbor_id) {
159 let distance = query.distance(&self.nodes[neighbor_id].vector, self.metric);
160 let item = HeapItem {
161 distance,
162 id: neighbor_id,
163 };
164
165 if results.len() < ef {
166 candidates.push(item.clone());
167 results.push(item);
168 } else if let Some(worst) = results.peek() {
169 if distance < worst.distance {
170 candidates.push(item.clone());
171 results.pop();
172 results.push(item);
173 }
174 }
175 }
176 }
177 }
178 }
179
180 let mut sorted: Vec<_> = results.into_iter().collect();
181 sorted.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap_or(Ordering::Equal));
182 sorted
183 }
184
185 fn prune_connections(&mut self, node_id: usize, layer: usize, m: usize) {
186 if node_id >= self.nodes.len() || layer >= self.nodes[node_id].connections.len() {
187 return;
188 }
189
190 let node_vector = self.nodes[node_id].vector.clone();
191 let connections = &self.nodes[node_id].connections[layer];
192
193 let mut distances: Vec<_> = connections
194 .iter()
195 .filter(|&&id| id < self.nodes.len())
196 .map(|&id| {
197 let dist = node_vector.distance(&self.nodes[id].vector, self.metric);
198 (id, dist)
199 })
200 .collect();
201
202 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
203
204 self.nodes[node_id].connections[layer] = distances
205 .into_iter()
206 .take(m)
207 .map(|(id, _)| id)
208 .collect();
209 }
210
211 fn random_layer(&self) -> usize {
212 let mut layer = 0;
213 while layer < self.max_layers && rand::random::<f32>() < 0.5 {
214 layer += 1;
215 }
216 layer
217 }
218
219 pub fn len(&self) -> usize {
220 self.nodes.len()
221 }
222
223 pub fn is_empty(&self) -> bool {
224 self.nodes.is_empty()
225 }
226}
227
/// Minimal thread-local xorshift64 pseudo-random generator.
///
/// The seed is a fixed constant, so every thread produces the same
/// deterministic sequence.
mod rand {
    use std::cell::Cell;

    thread_local! {
        static RNG: Cell<u64> = Cell::new(0x123456789abcdef0);
    }

    /// Advances the generator and returns a value drawn from `[0, 1)`, converted into `T`.
    pub fn random<T: From<f32>>() -> T {
        RNG.with(|rng| {
            let mut x = rng.get();
            x ^= x << 13;
            x ^= x >> 7;
            x ^= x << 17;
            rng.set(x);
            T::from(unit_f32(x))
        })
    }

    /// Maps the top 24 bits of `bits` onto `[0, 1)`.
    ///
    /// An `f32` mantissa holds exactly 24 bits, so this conversion is exact and
    /// never reaches 1.0. Dividing `x as f32` by `u64::MAX as f32` instead
    /// rounds both operands and can produce exactly 1.0.
    pub(super) fn unit_f32(bits: u64) -> f32 {
        (bits >> 40) as f32 / 16_777_216.0
    }
}