1use rusqlite::OptionalExtension;
50
51use crate::hnsw::{
52 config::HnswConfig,
53 distance_metric::DistanceMetric,
54 errors::HnswError,
55 layer::HnswLayer,
56 multilayer::{LevelDistributor, MultiLayerNodeManager},
57 neighborhood::NeighborhoodSearch,
58 storage::{VectorStorage, VectorStorageStats},
59};
60#[cfg(test)]
61use crate::hnsw::{
62 config::hnsw_config,
63 errors::HnswIndexError,
64};
65
/// An HNSW (Hierarchical Navigable Small World) vector index.
///
/// No methods are defined in this part of the file; they presumably live in
/// the `include!`d `index_api.rs` / `index_internal.rs` / `index_persist.rs`.
pub struct HnswIndex {
    /// Name of the index. Persisted metadata is looked up by this name
    /// (e.g. `HnswIndex::load_metadata(&conn, "load_test")` in the tests).
    pub(crate) name: String,

    /// Build/search configuration: dimension, `m`, `ef_construction`,
    /// `ef_search`, `ml`, distance metric and the multi-layer options.
    pub(crate) config: HnswConfig,

    /// Graph layers. NOTE(review): index 0 appears to be the base layer that
    /// holds every vector (tests expect `layer_stats[0]` to count all
    /// inserted vectors) — confirm.
    pub(crate) layers: Vec<HnswLayer>,

    /// Storage for raw vectors and their metadata, held as a trait object.
    /// After `load_metadata` it is empty; `load_with_vectors` populates it.
    pub(crate) storage: Box<dyn VectorStorage>,

    /// Ids of the nodes searches start from. NOTE(review): presumably vector
    /// ids (which are `u64` elsewhere in the API) — confirm.
    pub(crate) entry_points: Vec<u64>,

    /// Number of vectors the index reports. This can differ from
    /// `storage.vector_count()`: after `load_metadata` it carries the persisted
    /// count while storage is still empty.
    pub(crate) vector_count: usize,

    /// Neighborhood search helper. NOTE(review): assumed to drive the
    /// per-layer candidate search — confirm against `NeighborhoodSearch`.
    pub(crate) search_engine: NeighborhoodSearch,

    /// Samples insertion levels. It is `Some` only when multi-layer mode is
    /// enabled: the tests check `has_level_distributor()` against `enable_multilayer`.
    pub(crate) level_distributor: Option<LevelDistributor>,

    /// Multi-layer node bookkeeping. NOTE(review): presumably `Some` only in
    /// multi-layer mode, like `level_distributor` — confirm.
    pub(crate) multi_layer_manager: Option<MultiLayerNodeManager>,
}
108
/// Snapshot of an index's size and shape, as returned by `statistics()`.
#[derive(Debug, Clone)]
pub struct HnswIndexStats {
    /// Number of vectors in the index.
    pub vector_count: usize,

    /// Number of layers (equals the configured `max_layers` in the tests).
    pub layer_count: usize,

    /// Number of search entry points.
    pub entry_point_count: usize,

    /// Vector dimensionality taken from the index configuration.
    pub dimension: usize,

    /// Distance metric taken from the index configuration.
    pub distance_metric: DistanceMetric,

    /// Statistics reported by the vector storage backend.
    pub storage_stats: VectorStorageStats,

    /// One tuple per layer. `.0` is the number of vectors in that layer (the
    /// tests assert on it). NOTE(review): the meaning of `.1` (usize) and `.2`
    /// (f32) is not visible here — presumably edge count and average degree; confirm.
    pub layer_stats: Vec<(usize, usize, f32)>,
}
133
// The index implementation is split across the files below. `include!` pastes
// each file's contents into this module verbatim, so they share this file's
// `use` declarations (presumably why `rusqlite::OptionalExtension` is imported
// above without being used in the visible code) and its private items.
include!("index_api.rs");
include!("index_internal.rs");
include!("index_persist.rs");
139
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::SqliteGraph;
    use crate::hnsw::{DistanceMetric, HnswConfigBuilder};

    /// Scratch directory under the OS temp dir that is deleted when dropped.
    ///
    /// The directory is removed even if an assertion panics mid-test, and the
    /// path is portable (no hard-coded `/tmp`).
    struct TempDirGuard(std::path::PathBuf);

    impl TempDirGuard {
        /// Creates a fresh, empty directory named `<name>_<pid>` in the temp dir.
        ///
        /// The process id keeps concurrent runs of the test binary from
        /// sharing (and deleting) each other's databases.
        fn new(name: &str) -> Self {
            let dir = std::env::temp_dir().join(format!("{}_{}", name, std::process::id()));
            // Remove leftovers from a previous run that was killed before cleanup.
            let _ = std::fs::remove_dir_all(&dir);
            std::fs::create_dir_all(&dir).expect("failed to create test directory");
            TempDirGuard(dir)
        }

        /// Path of `file_name` inside the directory, as a `String` (the form
        /// passed to `SqliteGraph::open` / `Connection::open`).
        fn file(&self, file_name: &str) -> String {
            self.0
                .join(file_name)
                .to_str()
                .expect("temp directory path must be valid UTF-8")
                .to_owned()
        }
    }

    impl Drop for TempDirGuard {
        fn drop(&mut self) {
            // Best-effort cleanup: a failure here must not mask a test failure.
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    /// A new index reports its configuration and starts empty.
    #[test]
    fn test_hnsw_index_creation() {
        let config = HnswConfigBuilder::new()
            .dimension(3)
            .distance_metric(DistanceMetric::Euclidean)
            .build()
            .unwrap();

        let hnsw = HnswIndex::new("test_index", config).unwrap();
        let stats = hnsw.statistics().unwrap();

        assert_eq!(stats.vector_count, 0);
        assert_eq!(stats.dimension, 3);
        assert_eq!(stats.distance_metric, DistanceMetric::Euclidean);
    }

    /// Inserting a correctly-sized vector yields a positive id and bumps the count.
    #[test]
    fn test_vector_insertion() {
        let config = hnsw_config().dimension(3).build().unwrap();
        let mut hnsw = HnswIndex::new("test_insert", config).unwrap();
        let vector = vec![1.0, 0.0, 0.0];
        let metadata = serde_json::json!({"label": "test"});

        let vector_id = hnsw
            .insert_vector(&vector, Some(metadata))
            .expect("inserting a correctly-sized vector should succeed");
        assert!(vector_id > 0);

        let stats = hnsw.statistics().unwrap();
        assert_eq!(stats.vector_count, 1);
    }

    /// A vector whose length differs from the configured dimension is rejected.
    #[test]
    fn test_dimension_mismatch_error() {
        let mut hnsw = HnswIndex::new("test_dim_error", HnswConfig::default()).unwrap();
        let wrong_vector = vec![1.0, 0.0];

        let result = hnsw.insert_vector(&wrong_vector, None);
        assert!(result.is_err());

        let error = result.unwrap_err();
        assert!(matches!(
            error,
            HnswError::Index(HnswIndexError::VectorDimensionMismatch { .. })
        ));
    }

    /// Searching an empty index returns no results rather than an error.
    #[test]
    fn test_empty_search() {
        let hnsw = HnswIndex::new("test_empty_search", HnswConfig::default()).unwrap();
        let query = vec![1.0; 768];

        let results = hnsw.search(&query, 5).unwrap();
        assert!(results.is_empty());
    }

    /// A stored vector and its metadata round-trip through `get_vector`.
    #[test]
    fn test_vector_retrieval() {
        let config = hnsw_config().dimension(3).build().unwrap();
        let mut hnsw = HnswIndex::new("test_retrieval", config).unwrap();
        let vector = vec![1.0, 0.0, 0.0];
        let metadata = serde_json::json!({"label": "test"});

        let vector_id = hnsw.insert_vector(&vector, Some(metadata.clone())).unwrap();
        let result = hnsw.get_vector(vector_id).unwrap();

        assert!(result.is_some());
        let (retrieved_vector, retrieved_metadata) = result.unwrap();
        assert_eq!(retrieved_vector, vector);
        assert_eq!(retrieved_metadata, metadata);
    }

    /// Indexes created through `SqliteGraph` carry the requested configuration.
    #[test]
    fn test_sqlite_graph_integration() {
        let graph = SqliteGraph::open_in_memory().unwrap();
        let config = HnswConfigBuilder::new()
            .dimension(4)
            .distance_metric(DistanceMetric::Cosine)
            .build()
            .unwrap();

        let mut hnsw_indexes = graph.hnsw_index("test_index", config).unwrap();
        let hnsw = hnsw_indexes.get("test_index").unwrap();
        let stats = hnsw.statistics().unwrap();

        assert_eq!(stats.vector_count, 0);
        assert_eq!(stats.dimension, 4);
        assert_eq!(stats.distance_metric, DistanceMetric::Cosine);
    }

    /// Search returns at most `k` results ordered by ascending distance.
    #[test]
    fn test_basic_search_functionality() {
        let mut hnsw = HnswIndex::new(
            "test_search",
            HnswConfigBuilder::new()
                .dimension(2)
                .m_connections(4)
                .distance_metric(DistanceMetric::Euclidean)
                .build()
                .unwrap(),
        )
        .unwrap();

        // The four unit vectors on the axes.
        let vectors = vec![
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![-1.0, 0.0],
            vec![0.0, -1.0],
        ];
        for vector in &vectors {
            hnsw.insert_vector(vector, None).unwrap();
        }

        let query = vec![0.9, 0.1];
        let results = hnsw.search(&query, 2).unwrap();

        assert!(!results.is_empty());
        assert!(results.len() <= 2);

        for window in results.windows(2) {
            assert!(
                window[0].1 <= window[1].1,
                "results must be sorted by ascending distance"
            );
        }
    }

    /// Statistics reflect inserted vectors and configured layer count.
    #[test]
    fn test_index_statistics() {
        let mut hnsw = HnswIndex::new(
            "test_stats",
            HnswConfigBuilder::new()
                .dimension(3)
                .max_layers(3)
                .distance_metric(DistanceMetric::Euclidean)
                .build()
                .unwrap(),
        )
        .unwrap();

        for i in 1..=5 {
            let vector = vec![i as f32, (i * 2) as f32, (i * 3) as f32];
            hnsw.insert_vector(&vector, None).unwrap();
        }

        let stats = hnsw.statistics().unwrap();
        assert_eq!(stats.vector_count, 5);
        assert_eq!(stats.layer_count, 3);
        assert_eq!(stats.dimension, 3);
        assert!(!stats.layer_stats.is_empty());
    }

    /// Index metadata saved through one graph handle is visible from a fresh one.
    #[test]
    fn test_metadata_persistence() {
        let test_dir = TempDirGuard::new("test_hnsw_metadata_persistence");
        let db_path = test_dir.file("test.db");

        // Create the index and persist its metadata.
        {
            let graph = SqliteGraph::open(&db_path).unwrap();
            let config = HnswConfigBuilder::new()
                .dimension(128)
                .distance_metric(DistanceMetric::Euclidean)
                .build()
                .unwrap();

            let mut hnsw_indexes = graph.hnsw_index("persist_test", config).unwrap();
            let hnsw = hnsw_indexes.get("persist_test").unwrap();

            assert_eq!(hnsw.name(), "persist_test");
            assert_eq!(hnsw.config().dimension, 128);
            assert_eq!(hnsw.config().distance_metric, DistanceMetric::Euclidean);

            let conn = graph.connection();
            let conn_ref = conn.underlying();
            hnsw.save_metadata(conn_ref).unwrap();
        }

        // Reopen the database and check the index was restored.
        {
            let graph2 = SqliteGraph::open(&db_path).unwrap();

            let index_names = graph2.list_hnsw_indexes().unwrap();
            assert_eq!(index_names, vec!["persist_test".to_string()]);

            let loaded_dimension = graph2
                .get_hnsw_index_ref("persist_test", |hnsw| {
                    assert_eq!(hnsw.name(), "persist_test");
                    assert_eq!(hnsw.config().dimension, 128);
                    assert_eq!(hnsw.config().distance_metric, DistanceMetric::Euclidean);
                    hnsw.config().dimension
                })
                .unwrap();

            assert_eq!(loaded_dimension, 128);
        }
    }

    /// `load_metadata` restores only the metadata; `load_with_vectors` also
    /// loads the stored vectors so the index becomes searchable.
    #[test]
    fn test_vector_loading_and_rebuild() {
        use rusqlite::Connection;

        let test_dir = TempDirGuard::new("test_hnsw_vector_loading");
        let db_path = test_dir.file("test.db");

        // Write an index row and five vectors directly through SQL.
        {
            let conn = Connection::open(&db_path).unwrap();
            crate::schema::ensure_schema(&conn).unwrap();

            conn.execute(
                "INSERT INTO hnsw_indexes (name, dimension, m, ef_construction, distance_metric, vector_count, created_at, updated_at)
                 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
                rusqlite::params!["load_test", 3, 16, 200, "euclidean", 5, 1000, 1000],
            )
            .unwrap();

            let index_id = conn.last_insert_rowid();

            for i in 0..5 {
                let vector = vec![i as f32, (i * 2) as f32, (i * 3) as f32];
                let vector_bytes = bytemuck::cast_slice::<f32, u8>(&vector).to_vec();

                conn.execute(
                    "INSERT INTO hnsw_vectors (index_id, vector_data, metadata, created_at, updated_at)
                     VALUES (?1, ?2, ?3, ?4, ?5)",
                    rusqlite::params![index_id, vector_bytes, None::<String>, 1000, 1000],
                )
                .unwrap();
            }
        }

        {
            let conn2 = Connection::open(&db_path).unwrap();
            crate::schema::ensure_schema(&conn2).unwrap();

            // Metadata-only load: the count is restored but storage is empty.
            let hnsw_metadata = HnswIndex::load_metadata(&conn2, "load_test").unwrap();
            assert_eq!(hnsw_metadata.vector_count, 5);
            assert_eq!(hnsw_metadata.storage.vector_count().unwrap(), 0);

            // Full load: vectors are in storage and searchable.
            let mut hnsw_loaded = HnswIndex::load_with_vectors(&conn2, "load_test").unwrap();
            assert_eq!(hnsw_loaded.vector_count, 5);
            assert_eq!(hnsw_loaded.storage.vector_count().unwrap(), 5);

            let (vector, _) = hnsw_loaded.get_vector(1).unwrap().unwrap();
            assert_eq!(vector, vec![0.0, 0.0, 0.0]);

            let query = vec![2.0, 4.0, 6.0];
            let results = hnsw_loaded.search(&query, 3).unwrap();
            assert!(!results.is_empty());
        }
    }

    /// End to end: rows written with raw SQL are discovered and loaded by `SqliteGraph`.
    #[test]
    fn test_e2e_hnsw_persistence() {
        use rusqlite::Connection;

        let test_dir = TempDirGuard::new("test_hnsw_e2e_persistence");
        let db_path = test_dir.file("test.db");

        {
            let conn = Connection::open(&db_path).unwrap();
            crate::schema::ensure_schema(&conn).unwrap();

            conn.execute(
                "INSERT INTO hnsw_indexes (name, dimension, m, ef_construction, distance_metric, vector_count, created_at, updated_at)
                 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
                rusqlite::params!["e2e_test", 3, 16, 200, "euclidean", 5, 1000, 1000],
            )
            .unwrap();

            let index_id = conn.last_insert_rowid();

            for i in 0..5 {
                let vector = vec![i as f32, (i * 2) as f32, (i * 3) as f32];
                let vector_bytes = bytemuck::cast_slice::<f32, u8>(&vector).to_vec();
                let metadata = serde_json::json!({"label": format!("vector_{}", i)}).to_string();

                conn.execute(
                    "INSERT INTO hnsw_vectors (index_id, vector_data, metadata, created_at, updated_at)
                     VALUES (?1, ?2, ?3, ?4, ?5)",
                    rusqlite::params![index_id, vector_bytes, metadata, 1000, 1000],
                )
                .unwrap();
            }
        }

        {
            let graph = SqliteGraph::open(&db_path).unwrap();

            let index_names = graph.list_hnsw_indexes().unwrap();
            assert_eq!(index_names, vec!["e2e_test".to_string()]);

            let loaded_count = graph
                .get_hnsw_index_ref("e2e_test", |hnsw| {
                    assert_eq!(hnsw.vector_count(), 5);

                    // Vector ids start at 1, so id 1 is the first row inserted (i = 0).
                    let (vector, metadata) = hnsw.get_vector(1).unwrap().unwrap();
                    assert_eq!(vector, vec![0.0, 0.0, 0.0]);
                    assert_eq!(metadata, serde_json::json!({"label": "vector_0"}));

                    let query = vec![2.0, 4.0, 6.0];
                    let results = hnsw.search(&query, 3).unwrap();
                    assert!(!results.is_empty());

                    hnsw.vector_count()
                })
                .unwrap();

            assert_eq!(loaded_count, 5);
        }
    }

    /// Multi-layer mode creates a level distributor whose samples decay geometrically.
    #[test]
    fn test_multilayer_level_distribution() {
        let config = HnswConfig {
            dimension: 4,
            m: 16,
            ef_construction: 200,
            ef_search: 50,
            ml: 4,
            distance_metric: DistanceMetric::Euclidean,
            enable_multilayer: true,
            multilayer_level_distribution_base: Some(16),
            multilayer_deterministic_seed: Some(42),
        };

        let hnsw = HnswIndex::new("test_multilayer_dist", config).unwrap();

        assert!(
            hnsw.has_level_distributor(),
            "LevelDistributor should be initialized in multi-layer mode"
        );

        // Sample directly from a distributor with the same parameters and seed.
        let mut distributor = LevelDistributor::new(16.0, 4).with_seed(42);

        let mut level_counts = vec![0; 4];
        for _ in 0..1000 {
            let level = distributor.sample_level_internal();
            level_counts[level] += 1;
        }

        // With base 16, P(level >= 1) = 1/16 and P(level >= 2) = 1/256, so out of
        // 1000 samples expect ~938 at L0, ~62 at L1 and ~4 at L2.
        assert!(
            level_counts[0] >= 900 && level_counts[0] <= 950,
            "Level 0 should have ~938 samples, got {}",
            level_counts[0]
        );

        assert!(
            level_counts[1] >= 40 && level_counts[1] <= 80,
            "Level 1 should have ~62 samples, got {}",
            level_counts[1]
        );

        assert!(
            level_counts[2] >= 1 && level_counts[2] <= 10,
            "Level 2 should have ~4 samples, got {}",
            level_counts[2]
        );

        println!(
            "Level distribution (direct sampling): L0={}, L1={}, L2={}, L3={}",
            level_counts[0], level_counts[1], level_counts[2], level_counts[3]
        );
    }

    /// With multi-layer mode disabled every vector lands only in layer 0.
    #[test]
    fn test_single_layer_mode() {
        let config = HnswConfig {
            dimension: 4,
            m: 16,
            ef_construction: 200,
            ef_search: 50,
            ml: 4,
            distance_metric: DistanceMetric::Euclidean,
            enable_multilayer: false,
            multilayer_level_distribution_base: None,
            multilayer_deterministic_seed: None,
        };

        let hnsw = HnswIndex::new("test_single_layer", config.clone()).unwrap();

        assert!(
            !hnsw.has_level_distributor(),
            "LevelDistributor should not be initialized in single-layer mode"
        );

        let test_vector = vec![1.0, 0.0, 0.0, 0.0];
        let mut hnsw_mut = HnswIndex::new("test_single_layer_mut", config).unwrap();
        for _ in 0..100 {
            hnsw_mut.insert_vector(&test_vector, None).unwrap();
        }

        let stats = hnsw_mut.statistics().unwrap();

        assert_eq!(
            stats.layer_stats[0].0, 100,
            "Layer 0 should have 100 vectors"
        );

        assert_eq!(
            stats.layer_stats[1].0, 0,
            "Layer 1 should be empty in single-layer mode"
        );
        assert_eq!(
            stats.layer_stats[2].0, 0,
            "Layer 2 should be empty in single-layer mode"
        );
        assert_eq!(
            stats.layer_stats[3].0, 0,
            "Layer 3 should be empty in single-layer mode"
        );
    }

    /// HNSW top-k overlaps the brute-force top-k by at least 90%.
    #[test]
    fn test_multilayer_recall() {
        use std::collections::HashSet;

        let config = HnswConfig {
            dimension: 64,
            m: 16,
            ef_construction: 200,
            ef_search: 50,
            ml: 16,
            distance_metric: DistanceMetric::Euclidean,
            enable_multilayer: true,
            multilayer_level_distribution_base: Some(16),
            multilayer_deterministic_seed: Some(42),
        };

        let mut hnsw = HnswIndex::new("recall_test_unique", config).unwrap();
        let mut vectors = Vec::with_capacity(1000);

        for i in 0..1000 {
            let vector: Vec<f32> = (0..64)
                .map(|j| ((i * 64 + j) as f32 * 0.01).cos())
                .collect();
            hnsw.insert_vector(&vector, None).unwrap();
            vectors.push(vector);
        }

        let k = 10;
        let query = &vectors[0];

        let hnsw_results = hnsw.search(query, k).unwrap();
        let hnsw_ids: HashSet<_> = hnsw_results.iter().map(|(id, _)| *id).collect();

        fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
            a.iter()
                .zip(b.iter())
                .map(|(x, y)| (x - y).powi(2))
                .sum::<f32>()
                .sqrt()
        }

        // Brute-force ground truth; ids are 1-based in insertion order.
        let mut exact_results: Vec<(u64, f32)> = vectors
            .iter()
            .enumerate()
            .map(|(i, v)| (i as u64 + 1, euclidean_distance(query, v)))
            .collect();
        exact_results.sort_by(|a, b| a.1.total_cmp(&b.1));

        let exact_ids: HashSet<_> = exact_results.iter().take(k).map(|(id, _)| *id).collect();

        let overlap = hnsw_ids.intersection(&exact_ids).count();
        let recall = (overlap as f64 / k as f64) * 100.0;

        println!("HNSW results: {:?}", hnsw_results);
        println!("Exact top {}: {:?}", k, exact_ids);
        println!("Recall: {:.1}% ({}/{})", recall, overlap, k);
        assert!(
            recall >= 90.0,
            "Recall {:.1}% is below 90% threshold",
            recall
        );
    }

    /// Search time should grow sub-linearly with index size.
    ///
    /// This asserts on wall-clock ratios after inserting up to 10,000 vectors,
    /// so it is slow in debug builds and sensitive to machine load; it only
    /// runs on request.
    #[test]
    #[ignore = "timing-based benchmark: slow in debug builds and flaky under load; run with `cargo test -- --ignored`"]
    fn test_multilayer_search_complexity_ologn() {
        use std::time::Instant;

        let sizes = [100, 1000, 10000];
        let mut search_times = Vec::with_capacity(sizes.len());

        for size in sizes {
            let config = HnswConfig {
                dimension: 64,
                m: 16,
                ef_construction: 200,
                ef_search: 50,
                ml: 16,
                distance_metric: DistanceMetric::Euclidean,
                enable_multilayer: true,
                multilayer_level_distribution_base: Some(16),
                multilayer_deterministic_seed: Some(42),
            };

            let mut hnsw = HnswIndex::new(&format!("complexity_test_{}", size), config).unwrap();

            for i in 0..size {
                let vector: Vec<f32> = (0..64)
                    .map(|j| ((i * 64 + j) as f32 * 0.01).sin())
                    .collect();
                hnsw.insert_vector(&vector, None).unwrap();
            }

            let query: Vec<f32> = (0..64).map(|j| (j as f32 * 0.01).sin()).collect();
            let iterations = 10;
            let start = Instant::now();
            for _ in 0..iterations {
                let _ = hnsw.search(&query, 10).unwrap();
            }
            let elapsed = start.elapsed();
            let avg_time_ns = elapsed.as_nanos() / iterations as u128;
            search_times.push((size, avg_time_ns));

            println!("Size {}: avg search time = {} ns", size, avg_time_ns);
        }

        // A 10x larger index should cost well under 10x the search time.
        let ratio_100_to_1000 = search_times[1].1 as f64 / search_times[0].1 as f64;
        println!("Time ratio (1000/100): {:.2}x", ratio_100_to_1000);
        assert!(
            ratio_100_to_1000 < 10.0,
            "Search time ratio {:.2}x suggests worse than log scaling; expected < 10x for O(log N)",
            ratio_100_to_1000
        );

        let ratio_1000_to_10000 = search_times[2].1 as f64 / search_times[1].1 as f64;
        println!("Time ratio (10000/1000): {:.2}x", ratio_1000_to_10000);
        assert!(
            ratio_1000_to_10000 < 10.0,
            "Search time ratio {:.2}x suggests worse than log scaling; expected < 10x for O(log N)",
            ratio_1000_to_10000
        );

        let overall_ratio = search_times[2].1 as f64 / search_times[0].1 as f64;
        println!("Overall time ratio (10000/100): {:.2}x", overall_ratio);
        assert!(
            overall_ratio < 50.0,
            "Overall search time ratio {:.2}x suggests linear scaling; expected < 50x for O(log N) (linear would be 100x)",
            overall_ratio
        );
    }

    /// In multi-layer mode upper layers hold progressively fewer vectors.
    #[test]
    fn test_multilayer_insert_layers_correct() {
        let config = HnswConfig {
            dimension: 64,
            m: 16,
            ef_construction: 200,
            ef_search: 50,
            ml: 16,
            distance_metric: DistanceMetric::Euclidean,
            enable_multilayer: true,
            multilayer_level_distribution_base: Some(16),
            multilayer_deterministic_seed: Some(42),
        };

        let mut hnsw = HnswIndex::new("test_layers", config).unwrap();

        for i in 0..100 {
            let vector: Vec<f32> = (0..64)
                .map(|j| ((i * 64 + j) as f32 * 0.01).cos())
                .collect();
            hnsw.insert_vector(&vector, None).unwrap();
        }

        let stats = hnsw.statistics().unwrap();

        println!("Layer stats: {:?}", stats.layer_stats);

        assert_eq!(
            stats.layer_stats[0].0, 100,
            "Layer 0 should have all 100 vectors"
        );

        // With base 16 about 1/16 of vectors (~6 of 100) should reach layer 1.
        let layer1_count = stats.layer_stats[1].0;
        assert!(
            layer1_count > 0 && layer1_count < 20,
            "Layer 1 should have some vectors (got {}), but not all",
            layer1_count
        );

        assert!(
            stats.layer_stats[0].0 >= stats.layer_stats[1].0,
            "Layer 0 should have >= Layer 1"
        );
        assert!(
            stats.layer_stats[1].0 >= stats.layer_stats[2].0,
            "Layer 1 should have >= Layer 2"
        );

        assert!(
            hnsw.has_level_distributor(),
            "LevelDistributor should be initialized"
        );
    }
}