Skip to main content

sqlitegraph/hnsw/
index.rs

1//! HNSW Vector Search Index API
2//!
3//! This module provides the main HNSW index implementation that integrates
4//! vector search capabilities with SQLiteGraph. It combines all the HNSW
5//! components (layers, neighborhood search, storage) into a cohesive API
6//! that follows SQLiteGraph patterns and conventions.
7//!
8//! # Architecture
9//!
10//! The HnswIndex serves as the main orchestrator that coordinates:
11//! - Vector storage and retrieval via VectorStorage trait
12//! - Layer management for the hierarchical graph structure
13//! - Neighborhood search for approximate nearest neighbors
14//! - Entry point management for multi-layer navigation
15//!
16//! # Integration with SQLiteGraph
17//!
18//! The HNSW index is designed to work seamlessly with SQLiteGraph:
19//! - Uses SqliteGraphError for consistent error handling
20//! - Follows SQLiteGraph method naming conventions
21//! - Integrates with existing SQLite schemas
22//! - Supports both in-memory and persistent storage
23//!
24//! # Examples
25//!
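//! ```rust
//! use sqlitegraph::{SqliteGraph, hnsw::{HnswConfigBuilder, DistanceMetric}};
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! let graph = SqliteGraph::open_in_memory()?;
//! let config = HnswConfigBuilder::new()
//!     .dimension(768)
//!     .distance_metric(DistanceMetric::Cosine)
//!     .build()?;
//!
//! // `hnsw_index` returns a handle over the graph's named indexes; bind it
//! // with `mut` so that `get_mut` can be called on it.
//! let mut hnsw = graph.hnsw_index("vectors", config)?;
//!
//! // Insert a vector (its length must equal the configured dimension) with metadata
//! let vector_data = vec![0.5_f32; 768];
//! let metadata = serde_json::json!({"label": "example"});
//! let vector_id = hnsw.get_mut("vectors").unwrap()
//!     .insert_vector(&vector_data, Some(metadata))?;
//! println!("Inserted vector {}", vector_id);
//!
//! // Search for the 10 most similar vectors
//! let query_vector = vec![0.5_f32; 768];
//! let results = hnsw.get_mut("vectors").unwrap()
//!     .search(&query_vector, 10)?;
//! for (id, distance) in results {
//!     println!("Vector {}: distance {}", id, distance);
//! }
//! # Ok(())
//! # }
//! ```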
48
49use rusqlite::OptionalExtension;
50
51use crate::hnsw::{
52    config::HnswConfig,
53    distance_metric::DistanceMetric,
54    errors::HnswError,
55    layer::HnswLayer,
56    multilayer::{LevelDistributor, MultiLayerNodeManager},
57    neighborhood::NeighborhoodSearch,
58    storage::{VectorStorage, VectorStorageStats},
59};
60#[cfg(test)]
61use crate::hnsw::{
62    config::hnsw_config,
63    errors::HnswIndexError,
64};
65
/// Main HNSW vector search index.
///
/// Provides approximate nearest neighbor search using the Hierarchical
/// Navigable Small World (HNSW) algorithm. Integrates with SQLiteGraph
/// to provide vector-augmented graph queries.
///
/// The method implementations live in `index_api.rs`, `index_internal.rs`
/// and `index_persist.rs`, which are pulled into this module via `include!`.
///
/// # Single-layer vs multi-layer mode
///
/// When `config.enable_multilayer` is `false`, every vector is placed in
/// layer 0 and `level_distributor` / `multi_layer_manager` are `None`.
/// When it is `true`, both are initialized and vectors are assigned to
/// layers by an exponential level distribution.
///
/// # Performance Characteristics
///
/// These are the characteristics typically cited for the HNSW algorithm.
/// They are targets, not guarantees of this implementation, and they depend
/// on `m`, `ef_construction`, `ef_search` and the data distribution.
///
/// - **Search Time**: O(log N) average case complexity
/// - **Memory Usage**: 2-3x vector data size overhead (NOTE(review): not measured here — confirm)
/// - **Build Time**: O(N log N) with construction parameters
/// - **Accuracy**: 95%+ recall for typical workloads (the unit tests only assert
///   >= 90% recall@10 on one synthetic query)
pub struct HnswIndex {
    /// Name of this index, used as its key for persistence and when several
    /// indexes are attached to one graph.
    pub(crate) name: String,

    /// HNSW configuration parameters (dimension, `m`, `ef_construction`,
    /// `ef_search`, distance metric and multi-layer settings).
    pub(crate) config: HnswConfig,

    /// Per-layer graphs indexed by layer number
    /// (0 = base layer, higher numbers = smaller layers).
    pub(crate) layers: Vec<HnswLayer>,

    /// Vector storage backend, held as a boxed trait object so different
    /// `VectorStorage` implementations can be plugged in.
    pub(crate) storage: Box<dyn VectorStorage>,

    /// Entry points for navigating the hierarchical structure.
    /// NOTE(review): presumably IDs of vectors present in the upper layers — confirm.
    pub(crate) entry_points: Vec<u64>,

    /// Number of vectors currently indexed.
    ///
    /// This is tracked separately from the storage backend and can differ from
    /// `storage.vector_count()`: after loading only an index's metadata it holds
    /// the persisted count while the storage is still empty.
    pub(crate) vector_count: usize,

    /// Neighborhood search engine
    pub(crate) search_engine: NeighborhoodSearch,

    /// Level distributor for exponential level assignment in multi-layer mode.
    /// `Some` only when `config.enable_multilayer` is `true`.
    pub(crate) level_distributor: Option<LevelDistributor>,

    /// Multi-layer node manager for tracking layer assignments and ID translation.
    /// `Some` only when `config.enable_multilayer` is `true`.
    pub(crate) multi_layer_manager: Option<MultiLayerNodeManager>,
}
108
/// Comprehensive statistics for an HNSW index.
///
/// An owned snapshot produced by `HnswIndex::statistics()`. It holds plain
/// values, so it does not change if the index is modified afterwards.
#[derive(Debug, Clone)]
pub struct HnswIndexStats {
    /// Total number of vectors in the index
    pub vector_count: usize,

    /// Number of layers in the hierarchical structure.
    ///
    /// Empty layers are counted too: an index configured with `max_layers(3)`
    /// reports 3 even when only the base layer holds nodes.
    pub layer_count: usize,

    /// Number of entry points (vectors in higher layers)
    pub entry_point_count: usize,

    /// Vector dimension
    pub dimension: usize,

    /// Distance metric being used
    pub distance_metric: DistanceMetric,

    /// Statistics reported by the storage backend
    pub storage_stats: VectorStorageStats,

    /// Per-layer statistics, one `(node_count, total_connections, avg_connections)`
    /// tuple per layer, indexed by layer number (index 0 is the base layer).
    pub layer_stats: Vec<(usize, usize, f32)>,
}
133
// Include split module implementations using the include! macro.
// This allows us to split the file while maintaining a single compilation unit:
// `include!` pastes each file's contents into this module as if they were written
// here. The included code therefore resolves names through this module's `use`
// imports and shares its privacy scope. Paths are resolved relative to the
// directory containing this file.
include!("index_api.rs");
include!("index_internal.rs");
include!("index_persist.rs");
139
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::SqliteGraph;
    use crate::hnsw::{DistanceMetric, HnswConfigBuilder};

    /// Scratch directory for tests that need an on-disk SQLite database.
    ///
    /// It lives under the platform temp directory rather than a hard-coded
    /// `/tmp`. It is suffixed with the process id so concurrent test runs do
    /// not clobber each other. It is removed on drop, so cleanup also happens
    /// when an assertion panics mid-test.
    struct TempTestDir {
        path: std::path::PathBuf,
    }

    impl TempTestDir {
        /// Creates a fresh, empty directory named `<name>_<pid>` in the temp dir.
        fn new(name: &str) -> Self {
            let path = std::env::temp_dir().join(format!("{}_{}", name, std::process::id()));
            // Remove leftovers from an earlier run that reused the same pid.
            let _ = std::fs::remove_dir_all(&path);
            std::fs::create_dir_all(&path).expect("failed to create test directory");
            TempTestDir { path }
        }

        /// Path of the `test.db` file inside this directory.
        fn db_path(&self) -> String {
            self.path
                .join("test.db")
                .to_str()
                .expect("temp dir path is not valid UTF-8")
                .to_owned()
        }
    }

    impl Drop for TempTestDir {
        fn drop(&mut self) {
            // Best-effort cleanup; a failure here must not mask the test outcome.
            let _ = std::fs::remove_dir_all(&self.path);
        }
    }

    #[test]
    fn test_hnsw_index_creation() {
        let config = HnswConfigBuilder::new()
            .dimension(3)
            .distance_metric(DistanceMetric::Euclidean)
            .build()
            .unwrap();

        let hnsw = HnswIndex::new("test_index", config).unwrap();
        let stats = hnsw.statistics().unwrap();

        assert_eq!(stats.vector_count, 0);
        assert_eq!(stats.dimension, 3);
        assert_eq!(stats.distance_metric, DistanceMetric::Euclidean);
    }

    #[test]
    fn test_vector_insertion() {
        let config = hnsw_config().dimension(3).build().unwrap();
        let mut hnsw = HnswIndex::new("test_insert", config).unwrap();
        let vector = vec![1.0, 0.0, 0.0];
        let metadata = serde_json::json!({"label": "test"});

        // `expect` already prints the error on failure, so no debug print is needed.
        let vector_id = hnsw
            .insert_vector(&vector, Some(metadata))
            .expect("insert_vector should succeed for a correctly sized vector");
        assert!(vector_id > 0);

        let stats = hnsw.statistics().unwrap();
        assert_eq!(stats.vector_count, 1);
    }

    #[test]
    fn test_dimension_mismatch_error() {
        let mut hnsw = HnswIndex::new("test_dim_error", HnswConfig::default()).unwrap();
        let wrong_vector = vec![1.0, 0.0]; // Default config expects 768 dimensions

        let result = hnsw.insert_vector(&wrong_vector, None);
        assert!(result.is_err());

        let error = result.unwrap_err();
        assert!(matches!(
            error,
            HnswError::Index(HnswIndexError::VectorDimensionMismatch { .. })
        ));
    }

    #[test]
    fn test_empty_search() {
        let hnsw = HnswIndex::new("test_empty_search", HnswConfig::default()).unwrap();
        let query = vec![1.0; 768];

        let results = hnsw.search(&query, 5).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_vector_retrieval() {
        let config = hnsw_config().dimension(3).build().unwrap();
        let mut hnsw = HnswIndex::new("test_retrieval", config).unwrap();
        let vector = vec![1.0, 0.0, 0.0];
        let metadata = serde_json::json!({"label": "test"});

        let vector_id = hnsw.insert_vector(&vector, Some(metadata.clone())).unwrap();
        let result = hnsw.get_vector(vector_id).unwrap();

        assert!(result.is_some());
        let (retrieved_vector, retrieved_metadata) = result.unwrap();
        assert_eq!(retrieved_vector, vector);
        assert_eq!(retrieved_metadata, metadata);
    }

    #[test]
    fn test_sqlite_graph_integration() {
        let graph = SqliteGraph::open_in_memory().unwrap();
        let config = HnswConfigBuilder::new()
            .dimension(4)
            .distance_metric(DistanceMetric::Cosine)
            .build()
            .unwrap();

        let mut hnsw_indexes = graph.hnsw_index("test_index", config).unwrap();
        let hnsw = hnsw_indexes.get("test_index").unwrap();
        let stats = hnsw.statistics().unwrap();

        assert_eq!(stats.vector_count, 0);
        assert_eq!(stats.dimension, 4);
        assert_eq!(stats.distance_metric, DistanceMetric::Cosine);
    }

    #[test]
    fn test_basic_search_functionality() {
        let mut hnsw = HnswIndex::new(
            "test_search",
            HnswConfigBuilder::new()
                .dimension(2)
                .m_connections(4)
                .distance_metric(DistanceMetric::Euclidean)
                .build()
                .unwrap(),
        )
        .unwrap();

        // Insert some test vectors
        let vectors = vec![
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![-1.0, 0.0],
            vec![0.0, -1.0],
        ];

        let mut vector_ids = Vec::new();
        for vector in vectors {
            let id = hnsw.insert_vector(&vector, None).unwrap();
            vector_ids.push(id);
        }

        // Search for nearest neighbors
        let query = vec![0.9, 0.1];
        let results = hnsw.search(&query, 2).unwrap();

        assert!(!results.is_empty());
        assert!(results.len() <= 2);

        // Results should be sorted by distance
        for window in results.windows(2) {
            assert!(window[0].1 <= window[1].1);
        }
    }

    #[test]
    fn test_index_statistics() {
        let mut hnsw = HnswIndex::new(
            "test_stats",
            HnswConfigBuilder::new()
                .dimension(3)
                .max_layers(3)
                .distance_metric(DistanceMetric::Euclidean) // Use Euclidean to avoid zero magnitude issues
                .build()
                .unwrap(),
        )
        .unwrap();

        // Insert some vectors (starting from 1 to avoid all-zero vector)
        for i in 1..=5 {
            let vector = vec![i as f32, (i * 2) as f32, (i * 3) as f32];
            hnsw.insert_vector(&vector, None).unwrap();
        }

        let stats = hnsw.statistics().unwrap();
        assert_eq!(stats.vector_count, 5);
        assert_eq!(stats.layer_count, 3);
        assert_eq!(stats.dimension, 3);
        assert!(!stats.layer_stats.is_empty());
    }

    #[test]
    fn test_metadata_persistence() {
        // Dropped last, after every connection below has been closed.
        let dir = TempTestDir::new("test_hnsw_metadata_persistence");
        let db_path = dir.db_path();

        // Create graph and index
        {
            let graph = SqliteGraph::open(&db_path).unwrap();
            let config = HnswConfigBuilder::new()
                .dimension(128)
                .distance_metric(DistanceMetric::Euclidean)
                .build()
                .unwrap();

            let mut hnsw_indexes = graph.hnsw_index("persist_test", config).unwrap();
            let hnsw = hnsw_indexes.get("persist_test").unwrap();

            // Verify index was created
            assert_eq!(hnsw.name(), "persist_test");
            assert_eq!(hnsw.config().dimension, 128);
            assert_eq!(hnsw.config().distance_metric, DistanceMetric::Euclidean);

            // Save metadata explicitly
            let conn = graph.connection();
            let conn_ref = conn.underlying();
            hnsw.save_metadata(conn_ref).unwrap();
        }

        // Reopen and verify metadata persists
        {
            let graph2 = SqliteGraph::open(&db_path).unwrap();

            // Check that index was loaded
            let index_names = graph2.list_hnsw_indexes().unwrap();
            assert_eq!(index_names, vec!["persist_test".to_string()]);

            // Get the loaded index
            let loaded_hnsw = graph2
                .get_hnsw_index_ref("persist_test", |hnsw| {
                    assert_eq!(hnsw.name(), "persist_test");
                    assert_eq!(hnsw.config().dimension, 128);
                    assert_eq!(hnsw.config().distance_metric, DistanceMetric::Euclidean);
                    hnsw.config().dimension
                })
                .unwrap();

            assert_eq!(loaded_hnsw, 128);
        }
    }

    #[test]
    fn test_vector_loading_and_rebuild() {
        use rusqlite::Connection;

        // Dropped last, after every connection below has been closed.
        let dir = TempTestDir::new("test_hnsw_vector_loading");
        let db_path = dir.db_path();

        // Create index and manually persist vectors to database
        {
            let conn = Connection::open(&db_path).unwrap();

            // Create schema
            crate::schema::ensure_schema(&conn).unwrap();

            // Create HNSW index metadata
            conn.execute(
                "INSERT INTO hnsw_indexes (name, dimension, m, ef_construction, distance_metric, vector_count, created_at, updated_at)
                 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
                rusqlite::params!["load_test", 3, 16, 200, "euclidean", 5, 1000, 1000],
            ).unwrap();

            let index_id = conn.last_insert_rowid();

            // Manually insert vectors into database
            for i in 0..5 {
                let vector = vec![i as f32, (i * 2) as f32, (i * 3) as f32];
                let vector_bytes = bytemuck::cast_slice::<f32, u8>(&vector).to_vec();

                conn.execute(
                    "INSERT INTO hnsw_vectors (index_id, vector_data, metadata, created_at, updated_at)
                     VALUES (?1, ?2, ?3, ?4, ?5)",
                    rusqlite::params![index_id, vector_bytes, None::<String>, 1000, 1000],
                ).unwrap();
            }
        }

        // Load index with vectors and verify rebuild works
        {
            let conn2 = Connection::open(&db_path).unwrap();
            crate::schema::ensure_schema(&conn2).unwrap();

            // Load metadata only (vectors not loaded yet)
            let hnsw_metadata = HnswIndex::load_metadata(&conn2, "load_test").unwrap();
            assert_eq!(hnsw_metadata.vector_count, 5);
            assert_eq!(hnsw_metadata.storage.vector_count().unwrap(), 0); // No vectors loaded

            // Load with vectors - this rebuilds the graph
            let mut hnsw_loaded = HnswIndex::load_with_vectors(&conn2, "load_test").unwrap();
            assert_eq!(hnsw_loaded.vector_count, 5);
            assert_eq!(hnsw_loaded.storage.vector_count().unwrap(), 5); // Vectors loaded

            // Verify we can retrieve vectors
            let (vector, _) = hnsw_loaded.get_vector(1).unwrap().unwrap();
            assert_eq!(vector, vec![0.0, 0.0, 0.0]);

            // Verify search works (graph was rebuilt)
            let query = vec![2.0, 4.0, 6.0];
            let results = hnsw_loaded.search(&query, 3).unwrap();
            assert!(!results.is_empty());
        }
    }

    #[test]
    fn test_e2e_hnsw_persistence() {
        use rusqlite::Connection;

        // Dropped last, after every connection below has been closed.
        let dir = TempTestDir::new("test_hnsw_e2e_persistence");
        let db_path = dir.db_path();

        // Create index and manually persist vectors to database
        {
            let conn = Connection::open(&db_path).unwrap();

            // Create schema
            crate::schema::ensure_schema(&conn).unwrap();

            // Create HNSW index metadata
            conn.execute(
                "INSERT INTO hnsw_indexes (name, dimension, m, ef_construction, distance_metric, vector_count, created_at, updated_at)
                 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
                rusqlite::params!["e2e_test", 3, 16, 200, "euclidean", 5, 1000, 1000],
            ).unwrap();

            let index_id = conn.last_insert_rowid();

            // Manually insert vectors into database (simulating what SQLiteVectorStorage would do)
            for i in 0..5 {
                let vector = vec![i as f32, (i * 2) as f32, (i * 3) as f32];
                let vector_bytes = bytemuck::cast_slice::<f32, u8>(&vector).to_vec();
                let metadata = serde_json::json!({"label": format!("vector_{}", i)}).to_string();

                conn.execute(
                    "INSERT INTO hnsw_vectors (index_id, vector_data, metadata, created_at, updated_at)
                     VALUES (?1, ?2, ?3, ?4, ?5)",
                    rusqlite::params![index_id, vector_bytes, metadata, 1000, 1000],
                ).unwrap();
            }
        }

        // Reopen database and verify index is restored with vectors via SqliteGraph
        {
            let graph = SqliteGraph::open(&db_path).unwrap();

            // Check that index was loaded
            let index_names = graph.list_hnsw_indexes().unwrap();
            assert_eq!(index_names, vec!["e2e_test".to_string()]);

            // Get the loaded index
            let loaded_count = graph
                .get_hnsw_index_ref("e2e_test", |hnsw| {
                    // Verify all vectors were loaded
                    assert_eq!(hnsw.vector_count(), 5);

                    // Verify we can retrieve a vector
                    let (vector, metadata) = hnsw.get_vector(1).unwrap().unwrap();
                    assert_eq!(vector, vec![0.0, 0.0, 0.0]);
                    assert_eq!(metadata, serde_json::json!({"label": "vector_0"}));

                    // Verify search works (graph was rebuilt)
                    let query = vec![2.0, 4.0, 6.0];
                    let results = hnsw.search(&query, 3).unwrap();
                    assert!(!results.is_empty());

                    hnsw.vector_count()
                })
                .unwrap();

            assert_eq!(loaded_count, 5);
        }
    }

    #[test]
    fn test_multilayer_level_distribution() {
        // Create HnswIndex with multi-layer enabled
        let config = HnswConfig {
            dimension: 4,
            m: 16,
            ef_construction: 200,
            ef_search: 50,
            ml: 4,
            distance_metric: DistanceMetric::Euclidean,
            enable_multilayer: true,
            multilayer_level_distribution_base: Some(16),
            multilayer_deterministic_seed: Some(42),
        };

        let hnsw = HnswIndex::new("test_multilayer_dist", config).unwrap();

        // Verify level distributor was initialized
        assert!(
            hnsw.has_level_distributor(),
            "LevelDistributor should be initialized in multi-layer mode"
        );

        // Sample 1000 levels from a standalone, seeded distributor built with the
        // same parameters as the index (base 16, 4 levels, seed 42).
        let mut distributor = LevelDistributor::new(16.0, 4).with_seed(42);

        let mut level_counts = vec![0; 4];
        for _ in 0..1000 {
            let level = distributor.sample_level_internal();
            level_counts[level] += 1;
        }

        // The distribution is:
        // - P(level = 0) = 1 - 1/16 = 15/16 ≈ 937.5 out of 1000 (only base layer)
        // - P(level = 1) = 1/16 - 1/256 ≈ 58.6 out of 1000 (layers 0, 1)
        // - P(level = 2) = 1/256 - 1/4096 ≈ 3.7 out of 1000 (layers 0, 1, 2)
        // - P(level = 3) = 1/4096 ≈ 0.24 out of 1000 (layers 0, 1, 2, 3)

        // Level 0 should have approximately 938 samples (allow 900-950 range)
        assert!(
            level_counts[0] >= 900 && level_counts[0] <= 950,
            "Level 0 should have ~938 samples, got {}",
            level_counts[0]
        );

        // Level 1 should have approximately 59 samples (allow 40-80 range)
        assert!(
            level_counts[1] >= 40 && level_counts[1] <= 80,
            "Level 1 should have ~59 samples, got {}",
            level_counts[1]
        );

        // Level 2 should have approximately 4 samples (allow 1-10 range)
        assert!(
            level_counts[2] >= 1 && level_counts[2] <= 10,
            "Level 2 should have ~4 samples, got {}",
            level_counts[2]
        );

        println!(
            "Level distribution (direct sampling): L0={}, L1={}, L2={}, L3={}",
            level_counts[0], level_counts[1], level_counts[2], level_counts[3]
        );

        // Note: this test samples a standalone distributor rather than the one owned
        // by `hnsw`. How insertion places vectors into layers is exercised by
        // `test_multilayer_insert_layers_correct`.
        // NOTE(review): an earlier comment here said full multi-layer insertion
        // (LayerMappings / ID translation) was deferred to "plan 15-02"; the
        // multi-layer insertion and recall tests suggest it has landed — confirm.
    }

    #[test]
    fn test_single_layer_mode() {
        // Create HnswIndex with single-layer mode (default)
        let config = HnswConfig {
            dimension: 4,
            m: 16,
            ef_construction: 200,
            ef_search: 50,
            ml: 4,
            distance_metric: DistanceMetric::Euclidean,
            enable_multilayer: false, // Single-layer mode
            multilayer_level_distribution_base: None,
            multilayer_deterministic_seed: None,
        };

        // A single index serves both checks; no second instance is needed.
        let mut hnsw = HnswIndex::new("test_single_layer", config).unwrap();

        // Verify level distributor is NOT initialized in single-layer mode
        assert!(
            !hnsw.has_level_distributor(),
            "LevelDistributor should not be initialized in single-layer mode"
        );

        // Insert 100 vectors
        let test_vector = vec![1.0, 0.0, 0.0, 0.0];
        for _ in 0..100 {
            hnsw.insert_vector(&test_vector, None).unwrap();
        }

        let stats = hnsw.statistics().unwrap();

        // In single-layer mode, all vectors should only be in layer 0
        assert_eq!(
            stats.layer_stats[0].0, 100,
            "Layer 0 should have 100 vectors"
        );

        // Higher layers should be empty
        assert_eq!(
            stats.layer_stats[1].0, 0,
            "Layer 1 should be empty in single-layer mode"
        );
        assert_eq!(
            stats.layer_stats[2].0, 0,
            "Layer 2 should be empty in single-layer mode"
        );
        assert_eq!(
            stats.layer_stats[3].0, 0,
            "Layer 3 should be empty in single-layer mode"
        );
    }

    #[test]
    fn test_multilayer_recall() {
        use std::collections::HashSet;

        let config = HnswConfig {
            dimension: 64,
            m: 16,
            ef_construction: 200,
            ef_search: 50,
            ml: 16,
            distance_metric: DistanceMetric::Euclidean,
            enable_multilayer: true, // Test multi-layer recall
            multilayer_level_distribution_base: Some(16),
            multilayer_deterministic_seed: Some(42),
        };

        let mut hnsw = HnswIndex::new("recall_test_unique", config).unwrap();
        let mut vectors = Vec::new();

        // Insert 1000 deterministic pseudo-random vectors (cosine of a ramp)
        for i in 0..1000 {
            let vector: Vec<f32> = (0..64)
                .map(|j| ((i * 64 + j) as f32 * 0.01).cos())
                .collect();
            vectors.push(vector.clone());
            hnsw.insert_vector(&vector, None).unwrap();
        }

        let k = 10;
        let query = &vectors[0];

        // HNSW approximate results
        let hnsw_results = hnsw.search(query, k).unwrap();
        let hnsw_ids: HashSet<_> = hnsw_results.iter().map(|(id, _)| *id).collect();

        // Exact nearest neighbors (brute force)
        fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
            a.iter()
                .zip(b.iter())
                .map(|(x, y)| (x - y).powi(2))
                .sum::<f32>()
                .sqrt()
        }

        // IDs are assumed to be assigned sequentially from 1 in insertion order.
        let mut exact_results: Vec<_> = vectors
            .iter()
            .enumerate()
            .map(|(i, v)| (i as u64 + 1, euclidean_distance(query, v)))
            .collect();

        // Distances between finite vectors are finite, so `partial_cmp` never fails.
        exact_results.sort_by(|a, b| a.1.partial_cmp(&b.1).expect("distances are finite"));

        let exact_ids: HashSet<_> = exact_results.iter().take(k).map(|(id, _)| *id).collect();

        // Count overlap
        let overlap = hnsw_ids.intersection(&exact_ids).count();
        let recall = (overlap as f64 / k as f64) * 100.0;

        println!("HNSW results: {:?}", hnsw_results);
        println!("Exact top {}: {:?}", k, exact_ids);
        println!("Recall: {:.1}% ({}/{})", recall, overlap, k);
        assert!(
            recall >= 90.0,
            "Recall {:.1}% is below 90% threshold",
            recall
        );
    }

    // Wall-clock ratios are unreliable while other tests run in parallel threads,
    // and building a 10k-vector index is slow in debug builds, so this only runs
    // on request: `cargo test -- --ignored`.
    #[test]
    #[ignore = "timing-sensitive and builds a 10k-vector index; run with --ignored"]
    fn test_multilayer_search_complexity_ologn() {
        use std::time::Instant;

        // Test configurations with increasing dataset sizes
        let sizes = vec![100, 1000, 10000];
        let mut search_times = Vec::new();

        for size in sizes {
            let config = HnswConfig {
                dimension: 64,
                m: 16,
                ef_construction: 200,
                ef_search: 50,
                ml: 16,
                distance_metric: DistanceMetric::Euclidean,
                enable_multilayer: true,
                multilayer_level_distribution_base: Some(16),
                multilayer_deterministic_seed: Some(42),
            };

            let mut hnsw = HnswIndex::new(&format!("complexity_test_{}", size), config).unwrap();

            // Insert vectors
            for i in 0..size {
                let vector: Vec<f32> = (0..64)
                    .map(|j| ((i * 64 + j) as f32 * 0.01).sin())
                    .collect();
                hnsw.insert_vector(&vector, None).unwrap();
            }

            // Measure search time (average of multiple searches)
            let query: Vec<f32> = (0..64).map(|j| (j as f32 * 0.01).sin()).collect();
            let iterations = 10;
            let start = Instant::now();
            for _ in 0..iterations {
                let _ = hnsw.search(&query, 10).unwrap();
            }
            let elapsed = start.elapsed();
            let avg_time_ns = elapsed.as_nanos() / iterations as u128;
            search_times.push((size, avg_time_ns));

            println!("Size {}: avg search time = {} ns", size, avg_time_ns);
        }

        // Verify sub-linear scaling: T(1000) / T(100) should be < 10
        // Linear scaling would be 10x (1000/100), logarithmic is typically < 5x
        let ratio_100_to_1000 = search_times[1].1 as f64 / search_times[0].1 as f64;
        println!("Time ratio (1000/100): {:.2}x", ratio_100_to_1000);
        assert!(
            ratio_100_to_1000 < 10.0,
            "Search time ratio {:.2}x suggests worse than log scaling; expected < 10x for O(log N)",
            ratio_100_to_1000
        );

        // Verify sub-linear scaling: T(10000) / T(1000) should be < 10
        // Linear scaling would be 10x (10000/1000), but log should be better
        let ratio_1000_to_10000 = search_times[2].1 as f64 / search_times[1].1 as f64;
        println!("Time ratio (10000/1000): {:.2}x", ratio_1000_to_10000);
        assert!(
            ratio_1000_to_10000 < 10.0,
            "Search time ratio {:.2}x suggests worse than log scaling; expected < 10x for O(log N)",
            ratio_1000_to_10000
        );

        // Most importantly: overall T(10000) / T(100) should be MUCH better than linear (100x)
        let overall_ratio = search_times[2].1 as f64 / search_times[0].1 as f64;
        println!("Overall time ratio (10000/100): {:.2}x", overall_ratio);
        assert!(
            overall_ratio < 50.0,
            "Overall search time ratio {:.2}x suggests linear scaling; expected < 50x for O(log N) (linear would be 100x)",
            overall_ratio
        );
    }

    #[test]
    fn test_multilayer_insert_layers_correct() {
        let config = HnswConfig {
            dimension: 64,
            m: 16,
            ef_construction: 200,
            ef_search: 50,
            ml: 16,
            distance_metric: DistanceMetric::Euclidean,
            enable_multilayer: true,
            multilayer_level_distribution_base: Some(16),
            multilayer_deterministic_seed: Some(42),
        };

        let mut hnsw = HnswIndex::new("test_layers", config).unwrap();

        // Insert 100 vectors
        for i in 0..100 {
            let vector: Vec<f32> = (0..64)
                .map(|j| ((i * 64 + j) as f32 * 0.01).cos())
                .collect();
            hnsw.insert_vector(&vector, None).unwrap();
        }

        // Verify nodes are distributed across layers
        let stats = hnsw.statistics().unwrap();

        println!("Layer stats: {:?}", stats.layer_stats);

        // All 100 vectors should be in layer 0 (base layer)
        assert_eq!(
            stats.layer_stats[0].0, 100,
            "Layer 0 should have all 100 vectors"
        );

        // Layer 1 should have some vectors (approximately 100/16 = 6-7)
        // With seed 42 and exponential distribution, we expect ~6 vectors in layer 1
        let layer1_count = stats.layer_stats[1].0;
        assert!(
            layer1_count > 0 && layer1_count < 20,
            "Layer 1 should have some vectors (got {}), but not all",
            layer1_count
        );

        // Verify higher layers have fewer or equal nodes than lower layers
        assert!(
            stats.layer_stats[0].0 >= stats.layer_stats[1].0,
            "Layer 0 should have >= Layer 1"
        );
        assert!(
            stats.layer_stats[1].0 >= stats.layer_stats[2].0,
            "Layer 1 should have >= Layer 2"
        );

        // Verify multi-layer mode is enabled
        assert!(
            hnsw.has_level_distributor(),
            "LevelDistributor should be initialized"
        );
    }
}
858}