1use std::fmt;
2
3use bytemuck::cast_slice;
4use half::f16;
5
6use crate::distance::{DistanceMetric, vtype_to_scalar_kind};
7use crate::types::VectorType;
8
/// Tuning knobs for building and querying an HNSW index.
///
/// The type is plain data (`Copy`) and comparable, so parameter sets can be
/// checked for equality by callers and in tests.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HnswParams {
    /// Graph connectivity; passed to usearch as `connectivity`.
    pub m: usize,
    /// Candidate-list size while inserting; passed to usearch as `expansion_add`.
    pub ef_construction: usize,
    /// Candidate-list size while searching; passed to usearch as `expansion_search`.
    pub ef_search: usize,
}

impl Default for HnswParams {
    /// Returns the stock configuration: `m = 16`, `ef_construction = 200`,
    /// `ef_search = 64`.
    fn default() -> Self {
        Self {
            m: 16,
            ef_construction: 200,
            ef_search: 64,
        }
    }
}
26
/// Error type for every fallible index operation in this module.
///
/// Wraps a human-readable message; the message is public so callers can
/// inspect or build one directly.
#[derive(Debug)]
pub struct IndexError(pub String);

impl fmt::Display for IndexError {
    /// Renders the error as `index error: <message>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(message) = self;
        f.write_str("index error: ")?;
        f.write_str(message)
    }
}

// No underlying source error is kept, so the default trait methods suffice.
impl std::error::Error for IndexError {}
37
38pub struct HnswIndex {
40 inner: usearch::Index,
41 _dim: usize,
42 vtype: VectorType,
43}
44
45impl HnswIndex {
46 pub fn new(
48 dim: usize,
49 vtype: VectorType,
50 metric: DistanceMetric,
51 params: Option<HnswParams>,
52 ) -> Result<Self, IndexError> {
53 let p = params.unwrap_or_default();
54 let opts = usearch::IndexOptions {
55 dimensions: dim,
56 metric: metric.to_usearch(),
57 quantization: vtype_to_scalar_kind(vtype),
58 connectivity: p.m,
59 expansion_add: p.ef_construction,
60 expansion_search: p.ef_search,
61 multi: false,
62 };
63 let inner = usearch::Index::new(&opts).map_err(|e| IndexError(e.to_string()))?;
64 Ok(Self {
65 inner,
66 _dim: dim,
67 vtype,
68 })
69 }
70
71 pub fn len(&self) -> usize {
73 self.inner.size()
74 }
75
76 pub fn is_empty(&self) -> bool {
77 self.len() == 0
78 }
79
80 pub fn add(&self, key: u64, blob: &[u8]) -> Result<(), IndexError> {
82 self.reserve_if_needed()?;
83 match self.vtype {
84 VectorType::Float4 => {
85 let v: &[f32] = cast_slice(blob);
86 self.inner
87 .add(key, v)
88 .map_err(|e| IndexError(e.to_string()))
89 }
90 VectorType::Float8 => {
91 let v: &[f64] = cast_slice(blob);
92 self.inner
93 .add(key, v)
94 .map_err(|e| IndexError(e.to_string()))
95 }
96 VectorType::Int1 => {
97 let v: &[i8] = cast_slice(blob);
98 self.inner
99 .add(key, v)
100 .map_err(|e| IndexError(e.to_string()))
101 }
102 VectorType::Float2 => {
105 let v: &[f16] = cast_slice(blob);
106 let f: Vec<f32> = v.iter().map(|x| x.to_f32()).collect();
107 self.inner
108 .add(key, &f)
109 .map_err(|e| IndexError(e.to_string()))
110 }
111 VectorType::Int2 => {
112 let v: &[i16] = cast_slice(blob);
113 let f: Vec<f32> = v.iter().map(|x| *x as f32).collect();
114 self.inner
115 .add(key, &f)
116 .map_err(|e| IndexError(e.to_string()))
117 }
118 VectorType::Int4 => {
119 let v: &[i32] = cast_slice(blob);
120 let f: Vec<f32> = v.iter().map(|x| *x as f32).collect();
121 self.inner
122 .add(key, &f)
123 .map_err(|e| IndexError(e.to_string()))
124 }
125 }
126 }
127
128 pub fn search(&self, query_blob: &[u8], k: usize) -> Result<Vec<(u64, f32)>, IndexError> {
131 if self.is_empty() {
132 return Ok(Vec::new());
133 }
134
135 let matches = match self.vtype {
136 VectorType::Float4 => {
137 let q: &[f32] = cast_slice(query_blob);
138 self.inner.search(q, k)
139 }
140 VectorType::Float8 => {
141 let q: &[f64] = cast_slice(query_blob);
142 self.inner.search(q, k)
143 }
144 VectorType::Int1 => {
145 let q: &[i8] = cast_slice(query_blob);
146 self.inner.search(q, k)
147 }
148 VectorType::Float2 => {
149 let q: &[f16] = cast_slice(query_blob);
150 let f: Vec<f32> = q.iter().map(|x| x.to_f32()).collect();
151 self.inner.search(&f, k)
152 }
153 VectorType::Int2 => {
154 let q: &[i16] = cast_slice(query_blob);
155 let f: Vec<f32> = q.iter().map(|x| *x as f32).collect();
156 self.inner.search(&f, k)
157 }
158 VectorType::Int4 => {
159 let q: &[i32] = cast_slice(query_blob);
160 let f: Vec<f32> = q.iter().map(|x| *x as f32).collect();
161 self.inner.search(&f, k)
162 }
163 }
164 .map_err(|e| IndexError(e.to_string()))?;
165
166 Ok(matches.keys.into_iter().zip(matches.distances).collect())
167 }
168
169 pub fn remove(&self, key: u64) -> Result<(), IndexError> {
171 self.inner
172 .remove(key)
173 .map(|_| ())
174 .map_err(|e| IndexError(e.to_string()))
175 }
176
177 pub fn save_to_buffer(&self) -> Result<Vec<u8>, IndexError> {
179 let len = self.inner.serialized_length();
180 let mut buf = vec![0u8; len];
181 self.inner
182 .save_to_buffer(&mut buf)
183 .map_err(|e| IndexError(e.to_string()))?;
184 Ok(buf)
185 }
186
187 pub fn load_from_buffer(&self, buf: &[u8]) -> Result<(), IndexError> {
189 self.inner
190 .load_from_buffer(buf)
191 .map_err(|e| IndexError(e.to_string()))
192 }
193
194 fn reserve_if_needed(&self) -> Result<(), IndexError> {
196 if self.inner.size() >= self.inner.capacity() {
197 let new_cap = (self.inner.capacity() * 2).max(64);
198 self.inner
199 .reserve(new_cap)
200 .map_err(|e| IndexError(e.to_string()))?;
201 }
202 Ok(())
203 }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use bytemuck::cast_slice;

    /// Unit vectors along the x, y and z axes as `f32`.
    const AXES_F32: [[f32; 3]; 3] = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];

    /// Unit vectors along the x, y and z axes as `f64`.
    const AXES_F64: [[f64; 3]; 3] = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];

    /// Encodes `f32` values as the raw byte blob the index API expects.
    fn f32_blob(values: &[f32]) -> Vec<u8> {
        let bytes: &[u8] = cast_slice(values);
        bytes.to_vec()
    }

    /// Encodes `f64` values as the raw byte blob the index API expects.
    fn f64_blob(values: &[f64]) -> Vec<u8> {
        let bytes: &[u8] = cast_slice(values);
        bytes.to_vec()
    }

    /// Builds an empty three-dimensional index with default parameters.
    fn index_3d(vtype: VectorType, metric: DistanceMetric) -> HnswIndex {
        HnswIndex::new(3, vtype, metric, None).unwrap()
    }

    /// Adds the leading `f32` axis vectors under `keys`, one axis per key.
    fn add_axes_f32(idx: &HnswIndex, keys: &[u64]) {
        for (key, axis) in keys.iter().zip(AXES_F32.iter()) {
            idx.add(*key, &f32_blob(axis)).unwrap();
        }
    }

    /// Adds the leading `f64` axis vectors under `keys`, one axis per key.
    fn add_axes_f64(idx: &HnswIndex, keys: &[u64]) {
        for (key, axis) in keys.iter().zip(AXES_F64.iter()) {
            idx.add(*key, &f64_blob(axis)).unwrap();
        }
    }

    /// Asserts that index construction succeeded.
    fn assert_builds(built: Result<HnswIndex, IndexError>) {
        assert!(built.is_ok(), "expected Ok, got {:?}", built.err());
    }

    // --- parameters and construction -------------------------------------

    #[test]
    fn hnsw_params_default_values() {
        let HnswParams {
            m,
            ef_construction,
            ef_search,
        } = HnswParams::default();
        assert_eq!(m, 16);
        assert_eq!(ef_construction, 200);
        assert_eq!(ef_search, 64);
    }

    #[test]
    fn new_float4_l2_does_not_error() {
        assert_builds(HnswIndex::new(
            3,
            VectorType::Float4,
            DistanceMetric::L2,
            None,
        ));
    }

    #[test]
    fn new_float8_cosine_does_not_error() {
        assert_builds(HnswIndex::new(
            4,
            VectorType::Float8,
            DistanceMetric::Cosine,
            None,
        ));
    }

    #[test]
    fn new_with_custom_params_does_not_error() {
        let custom = HnswParams {
            m: 8,
            ef_construction: 64,
            ef_search: 32,
        };
        assert_builds(HnswIndex::new(
            3,
            VectorType::Float4,
            DistanceMetric::L2,
            Some(custom),
        ));
    }

    // --- len / is_empty ---------------------------------------------------

    #[test]
    fn len_empty_index_is_zero() {
        let idx = index_3d(VectorType::Float4, DistanceMetric::L2);
        assert_eq!(idx.len(), 0);
        assert!(idx.is_empty());
    }

    #[test]
    fn len_increases_after_add() {
        let idx = index_3d(VectorType::Float4, DistanceMetric::L2);

        idx.add(1, &f32_blob(&AXES_F32[0])).unwrap();
        assert_eq!(idx.len(), 1);
        assert!(!idx.is_empty());

        idx.add(2, &f32_blob(&AXES_F32[1])).unwrap();
        assert_eq!(idx.len(), 2);

        idx.add(3, &f32_blob(&AXES_F32[2])).unwrap();
        assert_eq!(idx.len(), 3);
    }

    // --- search -----------------------------------------------------------

    #[test]
    fn search_nearest_orthogonal_float4() {
        let idx = index_3d(VectorType::Float4, DistanceMetric::L2);
        add_axes_f32(&idx, &[1, 2, 3]);

        let results = idx.search(&f32_blob(&[0.9, 0.1, 0.0]), 1).unwrap();
        assert_eq!(results.len(), 1);
        let nearest = results[0].0;
        assert_eq!(
            nearest, 1,
            "expected key 1 ([1,0,0]) as nearest, got key {}",
            nearest
        );
    }

    #[test]
    fn search_returns_empty_on_empty_index() {
        let idx = index_3d(VectorType::Float4, DistanceMetric::L2);
        let results = idx.search(&f32_blob(&[1.0, 0.0, 0.0]), 5).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn search_k_larger_than_index_returns_all_vectors() {
        let idx = index_3d(VectorType::Float4, DistanceMetric::L2);
        add_axes_f32(&idx, &[1, 2]);

        let results = idx.search(&f32_blob(&[1.0, 0.0, 0.0]), 10).unwrap();
        assert_eq!(results.len(), 2);
    }

    // --- remove -----------------------------------------------------------

    #[test]
    fn remove_decreases_len() {
        let idx = index_3d(VectorType::Float4, DistanceMetric::L2);
        add_axes_f32(&idx, &[10, 20, 30]);
        assert_eq!(idx.len(), 3);

        idx.remove(20).unwrap();
        assert_eq!(idx.len(), 2);
    }

    #[test]
    fn remove_key_no_longer_returned_by_search() {
        let idx = index_3d(VectorType::Float4, DistanceMetric::L2);
        add_axes_f32(&idx, &[1, 2, 3]);

        idx.remove(2).unwrap();

        let results = idx.search(&f32_blob(&[0.0, 1.0, 0.0]), 3).unwrap();
        let returned_keys: Vec<u64> = results.iter().map(|&(key, _)| key).collect();
        assert!(
            !returned_keys.contains(&2),
            "removed key 2 should not appear in search results, got {:?}",
            returned_keys
        );
    }

    // --- serialization ----------------------------------------------------

    #[test]
    fn save_load_roundtrip_float4() {
        let src = index_3d(VectorType::Float4, DistanceMetric::L2);
        add_axes_f32(&src, &[1, 2, 3]);

        let buf = src.save_to_buffer().unwrap();
        assert!(!buf.is_empty(), "serialized buffer must not be empty");

        let dst = index_3d(VectorType::Float4, DistanceMetric::L2);
        dst.load_from_buffer(&buf).unwrap();

        assert_eq!(dst.len(), src.len());

        let results = dst.search(&f32_blob(&[0.9, 0.1, 0.0]), 1).unwrap();
        assert_eq!(results.len(), 1);
        let nearest = results[0].0;
        assert_eq!(
            nearest, 1,
            "post-load search should return key 1, got {}",
            nearest
        );
    }

    // --- Float8 -----------------------------------------------------------

    #[test]
    fn add_search_float8() {
        let idx = index_3d(VectorType::Float8, DistanceMetric::L2);
        add_axes_f64(&idx, &[1, 2, 3]);

        let results = idx.search(&f64_blob(&[0.1, 0.0, 0.9]), 1).unwrap();
        assert_eq!(results.len(), 1);
        let nearest = results[0].0;
        assert_eq!(
            nearest, 3,
            "expected key 3 ([0,0,1]) as nearest, got key {}",
            nearest
        );
    }

    #[test]
    fn save_load_roundtrip_float8() {
        let src = index_3d(VectorType::Float8, DistanceMetric::L2);
        add_axes_f64(&src, &[1, 2, 3]);

        let buf = src.save_to_buffer().unwrap();

        let dst = index_3d(VectorType::Float8, DistanceMetric::L2);
        dst.load_from_buffer(&buf).unwrap();

        assert_eq!(dst.len(), 3);

        let results = dst.search(&f64_blob(&[0.0, 0.9, 0.1]), 1).unwrap();
        assert_eq!(results.len(), 1);
        let nearest = results[0].0;
        assert_eq!(
            nearest, 2,
            "post-load search should return key 2, got {}",
            nearest
        );
    }

    // --- custom parameters --------------------------------------------------

    #[test]
    fn custom_params_index_behaves_correctly() {
        let custom = HnswParams {
            m: 4,
            ef_construction: 32,
            ef_search: 16,
        };
        let idx =
            HnswIndex::new(3, VectorType::Float4, DistanceMetric::Cosine, Some(custom)).unwrap();
        add_axes_f32(&idx, &[1, 2, 3]);

        assert_eq!(idx.len(), 3);

        let results = idx.search(&f32_blob(&[0.0, 0.1, 0.9]), 1).unwrap();
        assert_eq!(results.len(), 1);
        let nearest = results[0].0;
        assert_eq!(
            nearest, 3,
            "expected key 3 ([0,0,1]) as nearest under cosine, got {}",
            nearest
        );
    }
}