1use std::cmp::Reverse;
28use std::collections::{BinaryHeap, HashSet};
29use std::sync::{Arc, RwLock};
30
31use chrono::{DateTime, Utc};
32use rand::Rng;
33use serde::{Deserialize, Serialize};
34
35use crate::ruvector_native::SemanticVector;
36use crate::FrameworkError;
37
/// Tunable parameters of an [`HnswIndex`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HnswConfig {
    /// Maximum number of neighbors kept per node on every layer above layer 0.
    pub m: usize,

    /// Maximum number of neighbors kept per node on layer 0.
    pub m_max_0: usize,

    /// Candidate-list size passed to the per-layer search while inserting.
    pub ef_construction: usize,

    /// Lower bound on the layer-0 candidate-list size while querying;
    /// `search_knn` uses `max(ef_search, k)`.
    pub ef_search: usize,

    /// Level multiplier: a new node's level is `floor(-ln(u) * ml)` for a
    /// uniformly drawn random `u`.
    pub ml: f64,

    /// Required length of every inserted vector and every query.
    pub dimension: usize,

    /// Distance function used for all comparisons in the index.
    pub metric: DistanceMetric,
}
70
71impl Default for HnswConfig {
72 fn default() -> Self {
73 let m = 16;
74 Self {
75 m,
76 m_max_0: m * 2,
77 ef_construction: 200,
78 ef_search: 50,
79 ml: 1.0 / (m as f64).ln(),
80 dimension: 128,
81 metric: DistanceMetric::Cosine,
82 }
83 }
84}
85
/// Distance function used to compare vectors.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum DistanceMetric {
    /// Angular distance: `acos(cosine_similarity) / PI`, which lies in `[0, 1]`.
    /// A zero-norm vector is treated as having similarity `0.0`.
    Cosine,

    /// Square root of the sum of squared component differences.
    Euclidean,

    /// Sum of absolute component differences.
    Manhattan,
}
100
/// One stored vector together with its per-layer adjacency lists.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct HnswNode {
    /// The embedding, taken from `SemanticVector::embedding` on insert.
    vector: Vec<f32>,

    /// Caller-supplied identifier, taken from `SemanticVector::id`.
    external_id: String,

    /// Timestamp copied from the inserted `SemanticVector`.
    timestamp: DateTime<Utc>,

    /// Highest layer this node was assigned to when it was inserted.
    level: usize,

    /// Neighbor ids per layer: `connections[l]` holds the indices (into the
    /// index's node list) of this node's neighbors on layer `l`. Created with
    /// `level + 1` entries.
    connections: Vec<Vec<usize>>,
}
120
/// A single hit returned by the search methods of [`HnswIndex`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HnswSearchResult {
    /// Internal node index (the value returned by `insert`).
    pub node_id: usize,

    /// Identifier supplied by the caller when the vector was inserted.
    pub external_id: String,

    /// Distance between the query and this vector under the configured metric.
    pub distance: f32,

    /// Cosine similarity to the query; `Some` only when the index metric is
    /// [`DistanceMetric::Cosine`], otherwise `None`.
    pub similarity: Option<f32>,

    /// Timestamp stored with the vector.
    pub timestamp: DateTime<Utc>,
}
139
/// Snapshot of structural statistics produced by `HnswIndex::stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HnswStats {
    /// Number of stored vectors.
    pub node_count: usize,

    /// Number of layers, i.e. the highest layer index plus one (1 for an empty index).
    pub layer_count: usize,

    /// `nodes_per_layer[l]` is the number of nodes whose level is at least `l`.
    pub nodes_per_layer: Vec<usize>,

    /// Mean neighbor-list length per layer; `0.0` for a layer without nodes.
    pub avg_connections_per_layer: Vec<f64>,

    /// Sum of all neighbor-list lengths over all layers. A link stored by both
    /// of its endpoints is therefore counted twice.
    pub total_edges: usize,

    /// Node index where searches start, or `None` for an empty index.
    pub entry_point: Option<usize>,

    /// Rough estimate computed as
    /// `node_count * (dimension * 4 + 100 + m * 8 * layer_count)`.
    pub estimated_memory_bytes: usize,
}
164
/// Hierarchical Navigable Small World (HNSW) index over `f32` vectors.
pub struct HnswIndex {
    /// Parameters fixed at construction time.
    config: HnswConfig,

    /// All stored nodes; a node's id is its position in this vector.
    nodes: Vec<HnswNode>,

    /// Node that every search starts from; `None` while the index is empty.
    entry_point: Option<usize>,

    /// Highest layer currently present in the graph.
    max_layer: usize,

    /// Random source used to draw the level of each inserted node. Kept behind
    /// a lock so it can be advanced through `&self`.
    rng: Arc<RwLock<rand::rngs::StdRng>>,
}
184
185impl HnswIndex {
186 pub fn new() -> Self {
188 Self::with_config(HnswConfig::default())
189 }
190
191 pub fn with_config(config: HnswConfig) -> Self {
193 use rand::SeedableRng;
194 Self {
195 config,
196 nodes: Vec::new(),
197 entry_point: None,
198 max_layer: 0,
199 rng: Arc::new(RwLock::new(rand::rngs::StdRng::from_entropy())),
200 }
201 }
202
203 pub fn insert(&mut self, vector: SemanticVector) -> Result<usize, FrameworkError> {
213 if vector.embedding.len() != self.config.dimension {
214 return Err(FrameworkError::Config(format!(
215 "Vector dimension mismatch: expected {}, got {}",
216 self.config.dimension,
217 vector.embedding.len()
218 )));
219 }
220
221 let node_id = self.nodes.len();
222 let level = self.random_level();
223
224 let mut new_node = HnswNode {
226 vector: vector.embedding,
227 external_id: vector.id,
228 timestamp: vector.timestamp,
229 level,
230 connections: vec![Vec::new(); level + 1],
231 };
232
233 if self.entry_point.is_none() {
235 self.nodes.push(new_node);
237 self.entry_point = Some(node_id);
238 self.max_layer = level;
239 return Ok(node_id);
240 }
241
242 let entry_point = self.entry_point.unwrap();
244 let mut current_nearest = vec![entry_point];
245
246 for lc in (level + 1..=self.max_layer).rev() {
248 current_nearest = self.search_layer(&new_node.vector, ¤t_nearest, 1, lc);
249 }
250
251 for lc in (0..=level).rev() {
253 let candidates = self.search_layer(&new_node.vector, ¤t_nearest, self.config.ef_construction, lc);
254
255 let m = if lc == 0 { self.config.m_max_0 } else { self.config.m };
257 let neighbors = self.select_neighbors(&new_node.vector, candidates, m);
258
259 for &neighbor_id in &neighbors {
261 new_node.connections[lc].push(neighbor_id);
263 }
264
265 current_nearest = neighbors.clone();
266 }
267
268 self.nodes.push(new_node);
269
270 for lc in 0..=level {
272 let neighbors: Vec<usize> = self.nodes[node_id].connections[lc].clone();
273 for neighbor_id in neighbors {
274 if lc < self.nodes[neighbor_id].connections.len() {
276 self.nodes[neighbor_id].connections[lc].push(node_id);
277
278 let m_max = if lc == 0 { self.config.m_max_0 } else { self.config.m };
280 if self.nodes[neighbor_id].connections[lc].len() > m_max {
281 let neighbor_vec = self.nodes[neighbor_id].vector.clone();
282 let candidates = self.nodes[neighbor_id].connections[lc].clone();
283 let pruned = self.select_neighbors(&neighbor_vec, candidates, m_max);
284 self.nodes[neighbor_id].connections[lc] = pruned;
285 }
286 }
287 }
288 }
289
290 if level > self.max_layer {
292 self.max_layer = level;
293 self.entry_point = Some(node_id);
294 }
295
296 Ok(node_id)
297 }
298
299 pub fn insert_batch(&mut self, vectors: Vec<SemanticVector>) -> Result<Vec<usize>, FrameworkError> {
303 let mut ids = Vec::with_capacity(vectors.len());
304 for vector in vectors {
305 ids.push(self.insert(vector)?);
306 }
307 Ok(ids)
308 }
309
310 pub fn search_knn(&self, query: &[f32], k: usize) -> Result<Vec<HnswSearchResult>, FrameworkError> {
321 if query.len() != self.config.dimension {
322 return Err(FrameworkError::Config(format!(
323 "Query dimension mismatch: expected {}, got {}",
324 self.config.dimension,
325 query.len()
326 )));
327 }
328
329 if self.entry_point.is_none() {
330 return Ok(Vec::new());
331 }
332
333 let entry_point = self.entry_point.unwrap();
334 let mut current_nearest = vec![entry_point];
335
336 for lc in (1..=self.max_layer).rev() {
338 current_nearest = self.search_layer(query, ¤t_nearest, 1, lc);
339 }
340
341 let ef = self.config.ef_search.max(k);
343 let candidates = self.search_layer(query, ¤t_nearest, ef, 0);
344
345 let results: Vec<HnswSearchResult> = candidates
347 .iter()
348 .take(k)
349 .map(|&node_id| {
350 let node = &self.nodes[node_id];
351 let distance = self.distance(query, &node.vector);
352 let similarity = if self.config.metric == DistanceMetric::Cosine {
353 Some(self.cosine_similarity(query, &node.vector))
354 } else {
355 None
356 };
357
358 HnswSearchResult {
359 node_id,
360 external_id: node.external_id.clone(),
361 distance,
362 similarity,
363 timestamp: node.timestamp,
364 }
365 })
366 .collect();
367
368 Ok(results)
369 }
370
371 pub fn search_threshold(
383 &self,
384 query: &[f32],
385 threshold: f32,
386 max_results: Option<usize>,
387 ) -> Result<Vec<HnswSearchResult>, FrameworkError> {
388 let k = max_results.unwrap_or(1000).max(100);
390 let mut results = self.search_knn(query, k)?;
391
392 results.retain(|r| r.distance < threshold);
394
395 if let Some(max) = max_results {
397 results.truncate(max);
398 }
399
400 Ok(results)
401 }
402
403 pub fn stats(&self) -> HnswStats {
405 let node_count = self.nodes.len();
406 let layer_count = self.max_layer + 1;
407
408 let mut nodes_per_layer = vec![0; layer_count];
409 let mut connections_per_layer = vec![0; layer_count];
410
411 for node in &self.nodes {
412 for layer in 0..=node.level {
413 nodes_per_layer[layer] += 1;
414 connections_per_layer[layer] += node.connections[layer].len();
415 }
416 }
417
418 let avg_connections_per_layer: Vec<f64> = connections_per_layer
419 .iter()
420 .zip(&nodes_per_layer)
421 .map(|(conn, nodes)| {
422 if *nodes > 0 {
423 *conn as f64 / *nodes as f64
424 } else {
425 0.0
426 }
427 })
428 .collect();
429
430 let total_edges: usize = connections_per_layer.iter().sum();
431
432 let estimated_memory_bytes = node_count
434 * (self.config.dimension * 4 + 100 + self.config.m * 8 * layer_count); HnswStats {
439 node_count,
440 layer_count,
441 nodes_per_layer,
442 avg_connections_per_layer,
443 total_edges,
444 entry_point: self.entry_point,
445 estimated_memory_bytes,
446 }
447 }
448
449 fn search_layer(&self, query: &[f32], entry_points: &[usize], ef: usize, layer: usize) -> Vec<usize> {
453 let mut visited = HashSet::new();
454 let mut candidates = BinaryHeap::new();
455 let mut nearest = BinaryHeap::new();
456
457 for &ep in entry_points {
458 let dist = self.distance(query, &self.nodes[ep].vector);
459 candidates.push((Reverse(OrderedFloat(dist)), ep));
460 nearest.push((OrderedFloat(dist), ep));
461 visited.insert(ep);
462 }
463
464 while let Some((Reverse(OrderedFloat(dist)), current)) = candidates.pop() {
465 if let Some(&(OrderedFloat(max_dist), _)) = nearest.peek() {
467 if dist > max_dist {
468 break;
469 }
470 }
471
472 if current < self.nodes.len() && layer <= self.nodes[current].level {
474 for &neighbor in &self.nodes[current].connections[layer] {
475 if visited.insert(neighbor) {
476 let neighbor_dist = self.distance(query, &self.nodes[neighbor].vector);
477
478 if let Some(&(OrderedFloat(max_dist), _)) = nearest.peek() {
479 if neighbor_dist < max_dist || nearest.len() < ef {
480 candidates.push((Reverse(OrderedFloat(neighbor_dist)), neighbor));
481 nearest.push((OrderedFloat(neighbor_dist), neighbor));
482
483 if nearest.len() > ef {
484 nearest.pop();
485 }
486 }
487 } else {
488 candidates.push((Reverse(OrderedFloat(neighbor_dist)), neighbor));
489 nearest.push((OrderedFloat(neighbor_dist), neighbor));
490 }
491 }
492 }
493 }
494 }
495
496 let mut sorted_nearest: Vec<_> = nearest.into_iter().collect();
498 sorted_nearest.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
499 sorted_nearest.into_iter().map(|(_, id)| id).collect()
500 }
501
502 fn select_neighbors(&self, base: &[f32], candidates: Vec<usize>, m: usize) -> Vec<usize> {
504 if candidates.len() <= m {
505 return candidates;
506 }
507
508 let mut with_distances: Vec<_> = candidates
510 .into_iter()
511 .map(|id| {
512 let dist = self.distance(base, &self.nodes[id].vector);
513 (OrderedFloat(dist), id)
514 })
515 .collect();
516
517 with_distances.sort_by_key(|(dist, _)| *dist);
518 with_distances.into_iter().take(m).map(|(_, id)| id).collect()
519 }
520
521 fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
523 match self.config.metric {
524 DistanceMetric::Cosine => {
525 let similarity = self.cosine_similarity(a, b);
526 similarity.max(-1.0).min(1.0).acos() / std::f32::consts::PI
528 }
529 DistanceMetric::Euclidean => {
530 a.iter()
531 .zip(b.iter())
532 .map(|(x, y)| (x - y).powi(2))
533 .sum::<f32>()
534 .sqrt()
535 }
536 DistanceMetric::Manhattan => {
537 a.iter()
538 .zip(b.iter())
539 .map(|(x, y)| (x - y).abs())
540 .sum()
541 }
542 }
543 }
544
545 fn cosine_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
547 let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
548 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
549 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
550
551 if norm_a == 0.0 || norm_b == 0.0 {
552 return 0.0;
553 }
554
555 (dot / (norm_a * norm_b)).max(-1.0).min(1.0)
556 }
557
558 fn random_level(&self) -> usize {
560 let mut rng = self.rng.write().unwrap();
561 let uniform: f64 = rng.gen();
562 (-uniform.ln() * self.config.ml).floor() as usize
563 }
564
565 pub fn get_vector(&self, node_id: usize) -> Option<&Vec<f32>> {
567 self.nodes.get(node_id).map(|n| &n.vector)
568 }
569
570 pub fn get_external_id(&self, node_id: usize) -> Option<&str> {
572 self.nodes.get(node_id).map(|n| n.external_id.as_str())
573 }
574
575 pub fn len(&self) -> usize {
577 self.nodes.len()
578 }
579
580 pub fn is_empty(&self) -> bool {
582 self.nodes.is_empty()
583 }
584}
585
586impl Default for HnswIndex {
587 fn default() -> Self {
588 Self::new()
589 }
590}
591
/// `f32` wrapper with a total order, so distances can be used as keys in
/// `BinaryHeap` and in key-based sorts.
///
/// Ordinary numbers compare exactly like `f32` (including `-0.0 == 0.0`).
/// NaN is ordered after every other value and all NaNs compare equal, which
/// keeps the order total and transitive. `PartialEq` and `PartialOrd` are
/// defined through `Ord` so that all four traits agree with each other.
#[derive(Debug, Clone, Copy)]
struct OrderedFloat(f32);

impl PartialEq for OrderedFloat {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for OrderedFloat {}

impl PartialOrd for OrderedFloat {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for OrderedFloat {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        match self.0.partial_cmp(&other.0) {
            Some(ordering) => ordering,
            // At least one side is NaN: `false < true` places NaN last and
            // makes two NaNs equal.
            None => self.0.is_nan().cmp(&other.0.is_nan()),
        }
    }
}
603
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ruvector_native::Domain;
    use std::collections::HashMap;

    /// Builds a `SemanticVector` in the climate domain with the given id and
    /// embedding, the current time as timestamp, and no metadata.
    fn create_test_vector(id: &str, embedding: Vec<f32>) -> SemanticVector {
        SemanticVector {
            id: id.to_string(),
            embedding,
            domain: Domain::Climate,
            timestamp: Utc::now(),
            metadata: HashMap::new(),
        }
    }

    /// Creates an empty index that differs from the defaults only in `dimension`.
    fn index_with_dimension(dimension: usize) -> HnswIndex {
        HnswIndex::with_config(HnswConfig {
            dimension,
            ..Default::default()
        })
    }

    /// The two vectors closest to the query come back in distance order.
    #[test]
    fn test_hnsw_basic_insert_search() {
        let mut index = index_with_dimension(3);

        let fixtures = [
            ("v1", vec![1.0, 0.0, 0.0]),
            ("v2", vec![0.0, 1.0, 0.0]),
            ("v3", vec![0.9, 0.1, 0.0]),
        ];
        for (id, embedding) in fixtures {
            index.insert(create_test_vector(id, embedding)).unwrap();
        }

        assert_eq!(index.len(), 3);

        let query = vec![1.0, 0.0, 0.0];
        let hits = index.search_knn(&query, 2).unwrap();

        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].external_id, "v1");
        assert_eq!(hits[1].external_id, "v3");
    }

    /// Batch insertion returns one id per vector and stores all of them.
    #[test]
    fn test_hnsw_batch_insert() {
        let mut index = index_with_dimension(2);

        let batch = vec![
            create_test_vector("v1", vec![1.0, 0.0]),
            create_test_vector("v2", vec![0.0, 1.0]),
            create_test_vector("v3", vec![1.0, 1.0]),
        ];

        let assigned = index.insert_batch(batch).unwrap();
        assert_eq!(assigned.len(), 3);
        assert_eq!(index.len(), 3);
    }

    /// Threshold search only returns hits strictly below the threshold.
    #[test]
    fn test_hnsw_threshold_search() {
        let mut index = index_with_dimension(2);

        index.insert(create_test_vector("close", vec![1.0, 0.1])).unwrap();
        index.insert(create_test_vector("medium", vec![0.7, 0.7])).unwrap();
        index.insert(create_test_vector("far", vec![0.0, 1.0])).unwrap();

        let query = vec![1.0, 0.0];
        let hits = index.search_threshold(&query, 0.3, None).unwrap();

        assert!(!hits.is_empty());
        assert!(hits.iter().all(|hit| hit.distance < 0.3));
    }

    /// Under the cosine metric an identical vector ranks first and an
    /// opposite one ranks last.
    #[test]
    fn test_hnsw_cosine_similarity() {
        let mut index = HnswIndex::with_config(HnswConfig {
            dimension: 3,
            metric: DistanceMetric::Cosine,
            ..Default::default()
        });

        index
            .insert(create_test_vector("identical", vec![1.0, 0.0, 0.0]))
            .unwrap();
        index
            .insert(create_test_vector("orthogonal", vec![0.0, 1.0, 0.0]))
            .unwrap();
        index
            .insert(create_test_vector("opposite", vec![-1.0, 0.0, 0.0]))
            .unwrap();

        let query = vec![1.0, 0.0, 0.0];
        let hits = index.search_knn(&query, 3).unwrap();

        assert_eq!(hits[0].external_id, "identical");
        assert!(hits[0].distance < 0.01);

        assert_eq!(hits[2].external_id, "opposite");
    }

    /// Statistics reflect the number of inserted nodes and report edges.
    #[test]
    fn test_hnsw_stats() {
        let mut index = HnswIndex::with_config(HnswConfig {
            dimension: 2,
            m: 4,
            ..Default::default()
        });

        for i in 0..10 {
            let coordinate = i as f32;
            let vector = create_test_vector(&format!("v{}", i), vec![coordinate, coordinate]);
            index.insert(vector).unwrap();
        }

        let stats = index.stats();
        assert_eq!(stats.node_count, 10);
        assert!(stats.layer_count > 0);
        // Every node is present on layer 0.
        assert_eq!(stats.nodes_per_layer[0], 10);
        assert!(stats.total_edges > 0);
    }

    /// Inserting a vector of the wrong length is rejected.
    #[test]
    fn test_dimension_mismatch() {
        let mut index = index_with_dimension(3);

        // Two components, but the index expects three.
        let too_short = create_test_vector("bad", vec![1.0, 2.0]);
        let outcome = index.insert(too_short);
        assert!(outcome.is_err());
    }

    /// Searching an empty index succeeds and returns nothing.
    #[test]
    fn test_empty_index_search() {
        let index = HnswIndex::new();
        let query = vec![1.0; 128];
        let hits = index.search_knn(&query, 5).unwrap();
        assert!(hits.is_empty());
    }
}