Skip to main content

shodh_memory/vector_db/
spann.rs

1//! SPANN - Scalable Proximity-graph ANN for billion-scale vector search
2//!
3//! Implements disk-based IVF (Inverted File Index) with optional PQ compression.
4//! Designed for datasets that don't fit in RAM.
5//!
6//! # Architecture
7//!
8//! SPANN partitions vectors into √n clusters using k-means:
9//! - **Centroids**: Cluster centers stored in RAM for fast routing
10//! - **Posting Lists**: Vector assignments stored on disk, mmap'd on demand
11//! - **PQ Compression**: Optional product quantization for 32x storage reduction
12//!
13//! # File Format (v1)
14//!
15//! ```text
16//! ┌─────────────────────────────────────────────────────────┐
17//! │ Header (128 bytes)                                      │
18//! │ ├── magic: [u8; 4] = "SPAN"                             │
19//! │ ├── version: u32 = 1                                    │
20//! │ ├── num_vectors: u64                                    │
21//! │ ├── num_partitions: u32                                 │
22//! │ ├── dimension: u32                                      │
23//! │ ├── pq_enabled: u8 (0 or 1)                             │
24//! │ ├── pq_subvectors: u32                                  │
25//! │ ├── distance_metric: u8                                 │
26//! │ ├── checksum: u64                                       │
27//! │ ├── centroids_offset: u64                               │
28//! │ ├── codebook_offset: u64                                │
29//! │ ├── posting_index_offset: u64                           │
30//! │ ├── posting_data_offset: u64                            │
//! │ └── reserved: [u8; 58] (zero padding up to 128 bytes)   │
32//! ├─────────────────────────────────────────────────────────┤
33//! │ Centroids Section (aligned to 64 bytes)                 │
34//! │ └── [[f32; dimension]; num_partitions]                  │
35//! ├─────────────────────────────────────────────────────────┤
36//! │ PQ Codebook Section (if pq_enabled, aligned to 64)      │
37//! │ ├── num_subvectors: u32                                 │
38//! │ ├── num_centroids: u32 (always 256)                     │
39//! │ ├── subvec_dim: u32                                     │
40//! │ └── codebook: [[f32; subvec_dim]; 256] × num_subvectors │
41//! ├─────────────────────────────────────────────────────────┤
42//! │ Posting List Index (12 bytes per partition)             │
43//! │ └── [(offset: u64, count: u32); num_partitions]         │
44//! ├─────────────────────────────────────────────────────────┤
45//! │ Posting List Data                                       │
46//! │ └── For each partition:                                 │
47//! │     └── entries: [PostingEntry; count]                  │
48//! │         where PostingEntry =                            │
49//! │           - vector_id: u32                              │
50//! │           - pq_codes: [u8; num_subvectors] (if PQ)      │
51//! └─────────────────────────────────────────────────────────┘
52//! ```
53//!
54//! # Query Flow
55//!
56//! 1. Compute distances from query to all centroids
57//! 2. Select top-k nearest partitions (multi-probe)
58//! 3. Load posting lists for selected partitions
59//! 4. Compute PQ distances (ADC) or exact distances
60//! 5. Return top-k results across all probed partitions
61
62use anyhow::{anyhow, Result};
63use memmap2::{Mmap, MmapMut};
64use parking_lot::RwLock;
65use rand::seq::SliceRandom;
66use std::collections::BinaryHeap;
67use std::fs::{File, OpenOptions};
68use std::path::Path;
69use std::sync::atomic::{AtomicUsize, Ordering};
70use std::sync::Arc;
71use tracing::info;
72
73use super::pq::{PQConfig, ProductQuantizer, NUM_CENTROIDS};
74use super::vamana::DistanceMetric;
75
/// Magic bytes stored in the first four bytes of a serialized header.
const MAGIC: [u8; 4] = *b"SPAN";
/// On-disk format version; `SpannHeader::from_bytes` rejects any other value.
const VERSION: u32 = 1;
/// Length in bytes of the serialized header produced by `SpannHeader::to_bytes`.
const HEADER_SIZE: usize = 128;
/// Alignment handed to `align_to` when `save_to_file` lays out the file sections.
const ALIGNMENT: usize = 64;
/// Size in bytes of one posting-index entry.
const POSTING_INDEX_ENTRY_SIZE: usize = 12; // u64 offset + u32 count
81
/// SPANN configuration
///
/// Tunables read by `SpannIndex` when building, searching and saving an index.
#[derive(Debug, Clone)]
pub struct SpannConfig {
    /// Vector dimension. `SpannIndex::build` rejects input whose vectors have a
    /// different length.
    pub dimension: usize,
    /// Number of partitions (default: √n, set during build).
    /// `None` lets `compute_partitions` derive the count from the dataset size.
    pub num_partitions: Option<usize>,
    /// Enable PQ compression (32x storage reduction, ~5% recall loss).
    /// When `true`, `build` trains a product quantizer and stores PQ codes in
    /// every posting entry; `search` returns an error if no quantizer is present.
    pub use_pq: bool,
    /// Number of partitions to probe during search (default: 10)
    pub num_probes: usize,
    /// K-means iterations for clustering (upper bound; clustering stops early
    /// once no assignment changes)
    pub kmeans_iterations: usize,
    /// Distance metric used for centroid routing and k-means assignment
    pub distance_metric: DistanceMetric,
    /// Minimum vectors per partition before merge.
    /// NOTE(review): not read anywhere in the visible part of this file.
    pub min_partition_size: usize,
    /// Maximum vectors per partition before split.
    /// NOTE(review): not read anywhere in the visible part of this file.
    pub max_partition_size: usize,
}
102
103impl Default for SpannConfig {
104    fn default() -> Self {
105        Self {
106            dimension: 384,
107            num_partitions: None, // Auto-compute as √n
108            use_pq: true,
109            num_probes: 10,
110            kmeans_iterations: 25,
111            distance_metric: DistanceMetric::NormalizedDotProduct,
112            min_partition_size: 100,
113            max_partition_size: 10000,
114        }
115    }
116}
117
118impl SpannConfig {
119    /// Create config for MiniLM embeddings (384 dims)
120    pub fn minilm() -> Self {
121        Self {
122            dimension: 384,
123            ..Default::default()
124        }
125    }
126
127    /// Create config for CLIP embeddings (768 dims)
128    pub fn clip() -> Self {
129        Self {
130            dimension: 768,
131            ..Default::default()
132        }
133    }
134
135    /// Compute optimal partition count for dataset size
136    pub fn compute_partitions(&self, num_vectors: usize) -> usize {
137        self.num_partitions
138            .unwrap_or_else(|| ((num_vectors as f64).sqrt().ceil() as usize).max(1))
139    }
140}
141
/// One record of a posting list: a vector id plus, when PQ is enabled, the
/// vector's PQ codes.
#[derive(Debug, Clone)]
pub struct PostingEntry {
    /// Vector ID (maps back to original storage)
    pub vector_id: u32,
    /// PQ codes (if PQ enabled)
    pub pq_codes: Option<Vec<u8>>,
}

impl PostingEntry {
    /// Number of bytes one entry occupies on disk: the `u32` vector id followed
    /// by `pq_subvectors` one-byte PQ codes (pass 0 when PQ is disabled).
    pub fn serialized_size(pq_subvectors: usize) -> usize {
        const VECTOR_ID_BYTES: usize = std::mem::size_of::<u32>();
        VECTOR_ID_BYTES + pq_subvectors
    }
}
157
/// Partition metadata
///
/// In-memory form of one cluster: `SpannIndex::build` creates one `Partition`
/// per centroid and pushes every input vector into the partition of its
/// nearest centroid.
#[derive(Debug, Clone)]
pub struct Partition {
    /// Partition ID (index of the centroid it was created from)
    pub id: u32,
    /// Centroid vector (copy of the k-means centroid for this partition)
    pub centroid: Vec<f32>,
    /// Entries in this partition (its posting list)
    pub entries: Vec<PostingEntry>,
}
168
/// File header for SPANN index
///
/// Serialized and parsed field by field (little-endian) by `to_bytes` and
/// `from_bytes`; the struct's in-memory layout is never written to disk.
///
/// NOTE(review): under `packed` the fields below add up to 130 bytes
/// (70 bytes of data + 60 reserved) while `HEADER_SIZE` is 128. This is harmless
/// because serialization is manual, but the `reserved` length does not match the
/// on-disk padding.
#[repr(C, packed)]
#[derive(Debug, Clone, Copy)]
struct SpannHeader {
    /// Must equal `MAGIC`.
    magic: [u8; 4],
    /// Must equal `VERSION`.
    version: u32,
    /// Number of vectors in the index.
    num_vectors: u64,
    /// Number of partitions (and centroids).
    num_partitions: u32,
    /// Vector dimension.
    dimension: u32,
    /// 1 when a PQ codebook and PQ codes are present, 0 otherwise.
    pq_enabled: u8,
    /// Number of PQ code bytes per posting entry (0 without PQ).
    pq_subvectors: u32,
    /// Metric code: 0 = NormalizedDotProduct, 1 = Euclidean, 2 = Cosine.
    distance_metric: u8,
    /// Checksum of everything after the header, as computed in `save_to_file`.
    checksum: u64,
    /// Absolute file offset of the centroids section.
    centroids_offset: u64,
    /// Absolute file offset of the PQ codebook section.
    codebook_offset: u64,
    /// Absolute file offset of the posting-list index.
    posting_index_offset: u64,
    /// Absolute file offset of the posting-list data.
    posting_data_offset: u64,
    /// Unused; always zero.
    reserved: [u8; 60],
}
188
189impl SpannHeader {
190    fn new(
191        num_vectors: usize,
192        num_partitions: usize,
193        dimension: usize,
194        pq_enabled: bool,
195        pq_subvectors: usize,
196        distance_metric: DistanceMetric,
197    ) -> Self {
198        Self {
199            magic: MAGIC,
200            version: VERSION,
201            num_vectors: num_vectors as u64,
202            num_partitions: num_partitions as u32,
203            dimension: dimension as u32,
204            pq_enabled: if pq_enabled { 1 } else { 0 },
205            pq_subvectors: pq_subvectors as u32,
206            distance_metric: match distance_metric {
207                DistanceMetric::NormalizedDotProduct => 0,
208                DistanceMetric::Euclidean => 1,
209                DistanceMetric::Cosine => 2,
210            },
211            checksum: 0,
212            centroids_offset: 0,
213            codebook_offset: 0,
214            posting_index_offset: 0,
215            posting_data_offset: 0,
216            reserved: [0u8; 60],
217        }
218    }
219
220    fn to_bytes(&self) -> [u8; HEADER_SIZE] {
221        let mut bytes = [0u8; HEADER_SIZE];
222        let mut offset = 0;
223
224        bytes[offset..offset + 4].copy_from_slice(&self.magic);
225        offset += 4;
226        bytes[offset..offset + 4].copy_from_slice(&self.version.to_le_bytes());
227        offset += 4;
228        bytes[offset..offset + 8].copy_from_slice(&self.num_vectors.to_le_bytes());
229        offset += 8;
230        bytes[offset..offset + 4].copy_from_slice(&self.num_partitions.to_le_bytes());
231        offset += 4;
232        bytes[offset..offset + 4].copy_from_slice(&self.dimension.to_le_bytes());
233        offset += 4;
234        bytes[offset] = self.pq_enabled;
235        offset += 1;
236        bytes[offset..offset + 4].copy_from_slice(&self.pq_subvectors.to_le_bytes());
237        offset += 4;
238        bytes[offset] = self.distance_metric;
239        offset += 1;
240        bytes[offset..offset + 8].copy_from_slice(&self.checksum.to_le_bytes());
241        offset += 8;
242        bytes[offset..offset + 8].copy_from_slice(&self.centroids_offset.to_le_bytes());
243        offset += 8;
244        bytes[offset..offset + 8].copy_from_slice(&self.codebook_offset.to_le_bytes());
245        offset += 8;
246        bytes[offset..offset + 8].copy_from_slice(&self.posting_index_offset.to_le_bytes());
247        offset += 8;
248        bytes[offset..offset + 8].copy_from_slice(&self.posting_data_offset.to_le_bytes());
249        // Reserved bytes already 0
250        bytes
251    }
252
253    fn from_bytes(bytes: &[u8]) -> Result<Self> {
254        if bytes.len() < HEADER_SIZE {
255            return Err(anyhow!("Header too small"));
256        }
257
258        let magic: [u8; 4] = bytes[0..4].try_into()?;
259        if magic != MAGIC {
260            return Err(anyhow!("Invalid magic bytes: {:?}", magic));
261        }
262
263        let version = u32::from_le_bytes(bytes[4..8].try_into()?);
264        if version != VERSION {
265            return Err(anyhow!("Unsupported version: {}", version));
266        }
267
268        let mut offset = 8;
269        let num_vectors = u64::from_le_bytes(bytes[offset..offset + 8].try_into()?);
270        offset += 8;
271        let num_partitions = u32::from_le_bytes(bytes[offset..offset + 4].try_into()?);
272        offset += 4;
273        let dimension = u32::from_le_bytes(bytes[offset..offset + 4].try_into()?);
274        offset += 4;
275        let pq_enabled = bytes[offset];
276        offset += 1;
277        let pq_subvectors = u32::from_le_bytes(bytes[offset..offset + 4].try_into()?);
278        offset += 4;
279        let distance_metric = bytes[offset];
280        offset += 1;
281        let checksum = u64::from_le_bytes(bytes[offset..offset + 8].try_into()?);
282        offset += 8;
283        let centroids_offset = u64::from_le_bytes(bytes[offset..offset + 8].try_into()?);
284        offset += 8;
285        let codebook_offset = u64::from_le_bytes(bytes[offset..offset + 8].try_into()?);
286        offset += 8;
287        let posting_index_offset = u64::from_le_bytes(bytes[offset..offset + 8].try_into()?);
288        offset += 8;
289        let posting_data_offset = u64::from_le_bytes(bytes[offset..offset + 8].try_into()?);
290
291        Ok(Self {
292            magic,
293            version,
294            num_vectors,
295            num_partitions,
296            dimension,
297            pq_enabled,
298            pq_subvectors,
299            distance_metric,
300            checksum,
301            centroids_offset,
302            codebook_offset,
303            posting_index_offset,
304            posting_data_offset,
305            reserved: [0u8; 60],
306        })
307    }
308
309    fn distance_metric_enum(&self) -> DistanceMetric {
310        match self.distance_metric {
311            0 => DistanceMetric::NormalizedDotProduct,
312            1 => DistanceMetric::Euclidean,
313            2 => DistanceMetric::Cosine,
314            _ => DistanceMetric::NormalizedDotProduct,
315        }
316    }
317}
318
/// SPANN Index - Scalable disk-based ANN
///
/// Holds the routing state (centroids, optional PQ quantizer) in memory and the
/// posting lists either in memory (`partitions`, filled by `build`) or in a
/// memory-mapped index file (`mmap`). `search` uses the in-memory partitions
/// when they are non-empty and falls back to the mapped file otherwise.
pub struct SpannIndex {
    /// Configuration
    pub config: SpannConfig,
    /// Cluster centroids (kept in RAM for fast routing)
    centroids: Arc<RwLock<Vec<Vec<f32>>>>,
    /// PQ quantizer (if enabled)
    quantizer: Arc<RwLock<Option<ProductQuantizer>>>,
    /// Partitions (in-memory for building, cleared after save).
    /// NOTE(review): the visible `save_to_file` only reads this field and does
    /// not clear it; confirm where clearing is meant to happen.
    partitions: Arc<RwLock<Vec<Partition>>>,
    /// Memory-mapped file for disk access
    mmap: Arc<RwLock<Option<Mmap>>>,
    /// Number of vectors indexed
    num_vectors: AtomicUsize,
    /// Number of partitions
    num_partitions: AtomicUsize,
    /// Posting index offsets (from header); absolute file offset that
    /// `read_posting_list` uses to locate a partition's index entry
    posting_index_offset: AtomicUsize,
    /// Posting data offset (from header); base that `read_posting_list` adds to
    /// the per-partition list offset
    posting_data_offset: AtomicUsize,
}
340
341impl SpannIndex {
342    /// Create a new empty SPANN index
343    pub fn new(config: SpannConfig) -> Self {
344        Self {
345            config,
346            centroids: Arc::new(RwLock::new(Vec::new())),
347            quantizer: Arc::new(RwLock::new(None)),
348            partitions: Arc::new(RwLock::new(Vec::new())),
349            mmap: Arc::new(RwLock::new(None)),
350            num_vectors: AtomicUsize::new(0),
351            num_partitions: AtomicUsize::new(0),
352            posting_index_offset: AtomicUsize::new(0),
353            posting_data_offset: AtomicUsize::new(0),
354        }
355    }
356
357    /// Build index from vectors
358    ///
359    /// 1. Cluster vectors using k-means
360    /// 2. Assign each vector to nearest centroid
361    /// 3. Optionally train PQ and encode vectors
362    pub fn build(&mut self, vectors: Vec<Vec<f32>>) -> Result<()> {
363        if vectors.is_empty() {
364            return Err(anyhow!("Cannot build index from empty vectors"));
365        }
366
367        let n = vectors.len();
368        let dim = vectors[0].len();
369
370        if dim != self.config.dimension {
371            return Err(anyhow!(
372                "Vector dimension {} doesn't match config {}",
373                dim,
374                self.config.dimension
375            ));
376        }
377
378        let num_partitions = self.config.compute_partitions(n);
379        info!(
380            "Building SPANN index: {} vectors, {} partitions, PQ={}",
381            n, num_partitions, self.config.use_pq
382        );
383
384        let start = std::time::Instant::now();
385
386        // Step 1: K-means clustering to find centroids
387        let centroids = self.kmeans_cluster(&vectors, num_partitions)?;
388
389        // Step 2: Train PQ on vectors (if enabled)
390        let quantizer = if self.config.use_pq {
391            info!("Training PQ quantizer...");
392            let pq_config = PQConfig::for_dimension(dim);
393            let pq = ProductQuantizer::train(pq_config, &vectors)?;
394            Some(pq)
395        } else {
396            None
397        };
398
399        // Step 3: Assign vectors to partitions
400        let mut partitions: Vec<Partition> = centroids
401            .iter()
402            .enumerate()
403            .map(|(i, c)| Partition {
404                id: i as u32,
405                centroid: c.clone(),
406                entries: Vec::new(),
407            })
408            .collect();
409
410        for (vec_id, vector) in vectors.iter().enumerate() {
411            let partition_id = self.find_nearest_centroid(vector, &centroids);
412
413            let pq_codes = if let Some(ref pq) = quantizer {
414                Some(pq.encode(vector)?)
415            } else {
416                None
417            };
418
419            partitions[partition_id].entries.push(PostingEntry {
420                vector_id: vec_id as u32,
421                pq_codes,
422            });
423        }
424
425        // Log partition distribution
426        let sizes: Vec<usize> = partitions.iter().map(|p| p.entries.len()).collect();
427        let min_size = sizes.iter().min().copied().unwrap_or(0);
428        let max_size = sizes.iter().max().copied().unwrap_or(0);
429        let avg_size = if !sizes.is_empty() {
430            sizes.iter().sum::<usize>() / sizes.len()
431        } else {
432            0
433        };
434        info!(
435            "Partition distribution: min={}, max={}, avg={}",
436            min_size, max_size, avg_size
437        );
438
439        // Store results
440        *self.centroids.write() = centroids;
441        *self.quantizer.write() = quantizer;
442        *self.partitions.write() = partitions;
443        self.num_vectors.store(n, Ordering::Release);
444        self.num_partitions.store(num_partitions, Ordering::Release);
445
446        info!("SPANN build complete in {:?}", start.elapsed());
447
448        Ok(())
449    }
450
451    /// K-means clustering
452    fn kmeans_cluster(&self, vectors: &[Vec<f32>], k: usize) -> Result<Vec<Vec<f32>>> {
453        let n = vectors.len();
454        let dim = vectors[0].len();
455        let iterations = self.config.kmeans_iterations;
456
457        // Initialize centroids by random sampling
458        let mut rng = rand::thread_rng();
459        let mut indices: Vec<usize> = (0..n).collect();
460        indices.shuffle(&mut rng);
461
462        let mut centroids: Vec<Vec<f32>> = indices
463            .iter()
464            .take(k)
465            .map(|&i| vectors[i].clone())
466            .collect();
467
468        // Pad if needed
469        while centroids.len() < k {
470            let idx = indices[centroids.len() % n];
471            centroids.push(vectors[idx].clone());
472        }
473
474        let mut assignments = vec![0usize; n];
475
476        // K-means iterations
477        for iter in 0..iterations {
478            let mut changed = 0usize;
479
480            // Assign vectors to nearest centroid
481            for (i, vec) in vectors.iter().enumerate() {
482                let new_assignment = self.find_nearest_centroid(vec, &centroids);
483                if new_assignment != assignments[i] {
484                    changed += 1;
485                }
486                assignments[i] = new_assignment;
487            }
488
489            // Update centroids
490            let mut new_centroids: Vec<Vec<f32>> = vec![vec![0.0; dim]; k];
491            let mut counts = vec![0usize; k];
492
493            for (i, vec) in vectors.iter().enumerate() {
494                let c = assignments[i];
495                counts[c] += 1;
496                for (j, &v) in vec.iter().enumerate() {
497                    new_centroids[c][j] += v;
498                }
499            }
500
501            for c in 0..k {
502                if counts[c] > 0 {
503                    for j in 0..dim {
504                        new_centroids[c][j] /= counts[c] as f32;
505                    }
506                    centroids[c] = new_centroids[c].clone();
507                }
508            }
509
510            if iter % 5 == 0 {
511                info!(
512                    "K-means iter {}/{}: {} assignments changed",
513                    iter + 1,
514                    iterations,
515                    changed
516                );
517            }
518
519            // Early termination if converged
520            if changed == 0 {
521                info!("K-means converged at iteration {}", iter + 1);
522                break;
523            }
524        }
525
526        Ok(centroids)
527    }
528
529    /// Find nearest centroid for a vector
530    #[inline]
531    fn find_nearest_centroid(&self, vector: &[f32], centroids: &[Vec<f32>]) -> usize {
532        let mut best_idx = 0;
533        let mut best_dist = f32::MAX;
534
535        for (i, centroid) in centroids.iter().enumerate() {
536            let dist = self.compute_distance(vector, centroid);
537            if dist < best_dist {
538                best_dist = dist;
539                best_idx = i;
540            }
541        }
542
543        best_idx
544    }
545
546    /// Compute distance between two vectors
547    #[inline]
548    fn compute_distance(&self, a: &[f32], b: &[f32]) -> f32 {
549        match self.config.distance_metric {
550            DistanceMetric::Euclidean => a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum(),
551            DistanceMetric::NormalizedDotProduct | DistanceMetric::Cosine => {
552                // For normalized vectors, 1 - dot_product gives distance
553                let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
554                1.0 - dot
555            }
556        }
557    }
558
559    /// Search for k nearest neighbors
560    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(u32, f32)>> {
561        let centroids = self.centroids.read();
562        if centroids.is_empty() {
563            return Ok(Vec::new());
564        }
565
566        // Step 1: Find top-nprobes nearest partitions
567        let mut partition_distances: Vec<(usize, f32)> = centroids
568            .iter()
569            .enumerate()
570            .map(|(i, c)| (i, self.compute_distance(query, c)))
571            .collect();
572
573        partition_distances.sort_by(|a, b| a.1.total_cmp(&b.1));
574        let probe_partitions: Vec<usize> = partition_distances
575            .iter()
576            .take(self.config.num_probes)
577            .map(|(i, _)| *i)
578            .collect();
579
580        // Step 2: Search in selected partitions
581        let quantizer = self.quantizer.read();
582        let partitions = self.partitions.read();
583
584        // Build distance table for PQ (required for SPANN search)
585        let distance_table = if let Some(ref pq) = *quantizer {
586            Some(pq.build_distance_table(query)?)
587        } else {
588            anyhow::bail!(
589                "SPANN search requires PQ quantizer but use_pq is disabled. \
590                 PostingEntry stores only PQ codes, not original vectors."
591            );
592        };
593
594        // Collect candidates from all probed partitions
595        // Use max-heap to keep k smallest distances: pop() removes largest (worst) match
596        let mut heap: BinaryHeap<(ordered_float::OrderedFloat<f32>, u32)> =
597            BinaryHeap::with_capacity(k);
598
599        if !partitions.is_empty() {
600            // In-memory search (before save or after full load)
601            for &partition_id in &probe_partitions {
602                if partition_id >= partitions.len() {
603                    continue;
604                }
605
606                for entry in &partitions[partition_id].entries {
607                    let dist = if let (Some(ref table), Some(ref codes)) =
608                        (&distance_table, &entry.pq_codes)
609                    {
610                        quantizer
611                            .as_ref()
612                            .unwrap()
613                            .distance_with_table(table, codes)
614                    } else {
615                        // Would need original vectors - skip for now
616                        continue;
617                    };
618
619                    heap.push((ordered_float::OrderedFloat(dist), entry.vector_id));
620                    if heap.len() > k {
621                        heap.pop(); // Removes largest distance (worst match)
622                    }
623                }
624            }
625        } else {
626            // Disk-based search (after load_from_file)
627            let mmap_guard = self.mmap.read();
628            if let Some(ref mmap) = *mmap_guard {
629                let pq_subvectors = if quantizer.is_some() {
630                    self.config.dimension / 8 // PQ subvector count
631                } else {
632                    0
633                };
634
635                for &partition_id in &probe_partitions {
636                    let entries = self.read_posting_list(mmap, partition_id, pq_subvectors)?;
637
638                    for entry in entries {
639                        let dist = if let (Some(ref table), Some(ref codes)) =
640                            (&distance_table, &entry.pq_codes)
641                        {
642                            quantizer
643                                .as_ref()
644                                .unwrap()
645                                .distance_with_table(table, codes)
646                        } else {
647                            continue;
648                        };
649
650                        heap.push((ordered_float::OrderedFloat(dist), entry.vector_id));
651                        if heap.len() > k {
652                            heap.pop(); // Removes largest distance (worst match)
653                        }
654                    }
655                }
656            }
657        }
658
659        // Convert heap to sorted results (smallest distance first)
660        let mut results: Vec<(u32, f32)> = heap.into_iter().map(|(d, id)| (id, d.0)).collect();
661        results.sort_by(|a, b| a.1.total_cmp(&b.1));
662
663        Ok(results)
664    }
665
666    /// Read posting list from mmap'd file
667    fn read_posting_list(
668        &self,
669        mmap: &Mmap,
670        partition_id: usize,
671        pq_subvectors: usize,
672    ) -> Result<Vec<PostingEntry>> {
673        let index_offset = self.posting_index_offset.load(Ordering::Acquire);
674        let data_offset = self.posting_data_offset.load(Ordering::Acquire);
675
676        // Read posting index entry
677        let entry_offset = index_offset + partition_id * POSTING_INDEX_ENTRY_SIZE;
678        if entry_offset + POSTING_INDEX_ENTRY_SIZE > mmap.len() {
679            return Err(anyhow!("Posting index out of bounds"));
680        }
681
682        let list_offset =
683            u64::from_le_bytes(mmap[entry_offset..entry_offset + 8].try_into()?) as usize;
684        let count =
685            u32::from_le_bytes(mmap[entry_offset + 8..entry_offset + 12].try_into()?) as usize;
686
687        // Read posting list entries
688        let entry_size = PostingEntry::serialized_size(pq_subvectors);
689        let list_start = data_offset + list_offset;
690        let list_end = list_start + count * entry_size;
691
692        if list_end > mmap.len() {
693            return Err(anyhow!("Posting list data out of bounds"));
694        }
695
696        let mut entries = Vec::with_capacity(count);
697        let mut offset = list_start;
698
699        for _ in 0..count {
700            let vector_id = u32::from_le_bytes(mmap[offset..offset + 4].try_into()?);
701            offset += 4;
702
703            let pq_codes = if pq_subvectors > 0 {
704                let codes = mmap[offset..offset + pq_subvectors].to_vec();
705                offset += pq_subvectors;
706                Some(codes)
707            } else {
708                None
709            };
710
711            entries.push(PostingEntry {
712                vector_id,
713                pq_codes,
714            });
715        }
716
717        Ok(entries)
718    }
719
720    /// Save index to file
721    pub fn save_to_file(&self, path: &Path) -> Result<()> {
722        let start = std::time::Instant::now();
723
724        let centroids = self.centroids.read();
725        let quantizer = self.quantizer.read();
726        let partitions = self.partitions.read();
727
728        if centroids.is_empty() {
729            return Err(anyhow!("Cannot save empty index"));
730        }
731
732        let num_vectors = self.num_vectors.load(Ordering::Acquire);
733        let num_partitions = centroids.len();
734        let dimension = self.config.dimension;
735        let pq_enabled = quantizer.is_some();
736        let pq_subvectors = if pq_enabled { dimension / 8 } else { 0 };
737
738        // Calculate section offsets
739        let centroids_offset = align_to(HEADER_SIZE, ALIGNMENT);
740        let centroids_size = num_partitions * dimension * 4;
741
742        let codebook_offset = align_to(centroids_offset + centroids_size, ALIGNMENT);
743        let codebook_size = if pq_enabled {
744            // num_subvectors + num_centroids + subvec_dim + codebook data
745            4 + 4 + 4 + (pq_subvectors * NUM_CENTROIDS * 8 * 4)
746        } else {
747            0
748        };
749
750        let posting_index_offset = align_to(codebook_offset + codebook_size, ALIGNMENT);
751        let posting_index_size = num_partitions * POSTING_INDEX_ENTRY_SIZE;
752
753        let posting_data_offset = align_to(posting_index_offset + posting_index_size, ALIGNMENT);
754
755        // Calculate posting data size
756        let entry_size = PostingEntry::serialized_size(pq_subvectors);
757        let posting_data_size: usize = partitions
758            .iter()
759            .map(|p| p.entries.len() * entry_size)
760            .sum();
761
762        let total_size = posting_data_offset + posting_data_size;
763
764        // Create file
765        let file = OpenOptions::new()
766            .read(true)
767            .write(true)
768            .create(true)
769            .truncate(true)
770            .open(path)?;
771        file.set_len(total_size as u64)?;
772
773        let mut mmap = unsafe { MmapMut::map_mut(&file)? };
774
775        // Create header
776        let mut header = SpannHeader::new(
777            num_vectors,
778            num_partitions,
779            dimension,
780            pq_enabled,
781            pq_subvectors,
782            self.config.distance_metric,
783        );
784        header.centroids_offset = centroids_offset as u64;
785        header.codebook_offset = codebook_offset as u64;
786        header.posting_index_offset = posting_index_offset as u64;
787        header.posting_data_offset = posting_data_offset as u64;
788
789        // Write centroids
790        let mut offset = centroids_offset;
791        for centroid in centroids.iter() {
792            for &val in centroid {
793                mmap[offset..offset + 4].copy_from_slice(&val.to_le_bytes());
794                offset += 4;
795            }
796        }
797
798        // Write PQ codebook
799        if let Some(ref pq) = *quantizer {
800            offset = codebook_offset;
801            mmap[offset..offset + 4].copy_from_slice(&(pq_subvectors as u32).to_le_bytes());
802            offset += 4;
803            mmap[offset..offset + 4].copy_from_slice(&(NUM_CENTROIDS as u32).to_le_bytes());
804            offset += 4;
805            mmap[offset..offset + 4].copy_from_slice(&8u32.to_le_bytes()); // subvec_dim
806            offset += 4;
807
808            for subspace_centroids in &pq.centroids {
809                for centroid in subspace_centroids {
810                    for &val in centroid {
811                        mmap[offset..offset + 4].copy_from_slice(&val.to_le_bytes());
812                        offset += 4;
813                    }
814                }
815            }
816        }
817
818        // Write posting index and data
819        let mut data_write_offset: usize = 0;
820        for (partition_id, partition) in partitions.iter().enumerate() {
821            // Write index entry
822            let index_entry_offset = posting_index_offset + partition_id * POSTING_INDEX_ENTRY_SIZE;
823            mmap[index_entry_offset..index_entry_offset + 8]
824                .copy_from_slice(&(data_write_offset as u64).to_le_bytes());
825            mmap[index_entry_offset + 8..index_entry_offset + 12]
826                .copy_from_slice(&(partition.entries.len() as u32).to_le_bytes());
827
828            // Write posting list entries
829            offset = posting_data_offset + data_write_offset;
830            for entry in &partition.entries {
831                mmap[offset..offset + 4].copy_from_slice(&entry.vector_id.to_le_bytes());
832                offset += 4;
833
834                if let Some(ref codes) = entry.pq_codes {
835                    mmap[offset..offset + codes.len()].copy_from_slice(codes);
836                    offset += codes.len();
837                }
838            }
839
840            data_write_offset += partition.entries.len() * entry_size;
841        }
842
843        // Compute checksum and write header
844        let checksum = compute_checksum(&mmap[HEADER_SIZE..]);
845        header.checksum = checksum;
846        mmap[..HEADER_SIZE].copy_from_slice(&header.to_bytes());
847
848        mmap.flush()?;
849
850        info!(
851            "Saved SPANN index: {} vectors, {} partitions, {} bytes in {:?}",
852            num_vectors,
853            num_partitions,
854            total_size,
855            start.elapsed()
856        );
857
858        Ok(())
859    }
860
861    /// Load index from file
862    pub fn load_from_file(path: &Path) -> Result<Self> {
863        let start = std::time::Instant::now();
864
865        if !path.exists() {
866            return Err(anyhow!("Index file not found: {:?}", path));
867        }
868
869        let file = File::open(path)?;
870        let mmap = unsafe { Mmap::map(&file)? };
871
872        // Read and verify header
873        let header = SpannHeader::from_bytes(&mmap[..HEADER_SIZE])?;
874
875        let stored_checksum = header.checksum;
876        let computed_checksum = compute_checksum(&mmap[HEADER_SIZE..]);
877        if stored_checksum != computed_checksum {
878            return Err(anyhow!(
879                "Checksum mismatch: stored={}, computed={}",
880                stored_checksum,
881                computed_checksum
882            ));
883        }
884
885        let num_vectors = header.num_vectors as usize;
886        let num_partitions = header.num_partitions as usize;
887        let dimension = header.dimension as usize;
888        let pq_enabled = header.pq_enabled == 1;
889        let _pq_subvectors = header.pq_subvectors as usize;
890
891        // Read centroids
892        let mut centroids = Vec::with_capacity(num_partitions);
893        let mut offset = header.centroids_offset as usize;
894        for _ in 0..num_partitions {
895            let mut centroid = Vec::with_capacity(dimension);
896            for _ in 0..dimension {
897                let val = f32::from_le_bytes(mmap[offset..offset + 4].try_into()?);
898                centroid.push(val);
899                offset += 4;
900            }
901            centroids.push(centroid);
902        }
903
904        // Read PQ codebook
905        let quantizer = if pq_enabled {
906            offset = header.codebook_offset as usize;
907            let num_subvectors = u32::from_le_bytes(mmap[offset..offset + 4].try_into()?) as usize;
908            offset += 4;
909            let num_centroids = u32::from_le_bytes(mmap[offset..offset + 4].try_into()?) as usize;
910            offset += 4;
911            let subvec_dim = u32::from_le_bytes(mmap[offset..offset + 4].try_into()?) as usize;
912            offset += 4;
913
914            let mut pq_centroids = Vec::with_capacity(num_subvectors);
915            for _ in 0..num_subvectors {
916                let mut subspace = Vec::with_capacity(num_centroids);
917                for _ in 0..num_centroids {
918                    let mut centroid = Vec::with_capacity(subvec_dim);
919                    for _ in 0..subvec_dim {
920                        let val = f32::from_le_bytes(mmap[offset..offset + 4].try_into()?);
921                        centroid.push(val);
922                        offset += 4;
923                    }
924                    subspace.push(centroid);
925                }
926                pq_centroids.push(subspace);
927            }
928
929            let config = PQConfig {
930                dimension,
931                num_subvectors,
932                subvec_dim,
933                num_centroids,
934                kmeans_iterations: 20,
935            };
936
937            Some(ProductQuantizer {
938                config,
939                centroids: pq_centroids,
940                trained: true,
941            })
942        } else {
943            None
944        };
945
946        let config = SpannConfig {
947            dimension,
948            num_partitions: Some(num_partitions),
949            use_pq: pq_enabled,
950            distance_metric: header.distance_metric_enum(),
951            ..Default::default()
952        };
953
954        let index = SpannIndex {
955            config,
956            centroids: Arc::new(RwLock::new(centroids)),
957            quantizer: Arc::new(RwLock::new(quantizer)),
958            partitions: Arc::new(RwLock::new(Vec::new())), // Not loaded - use mmap
959            mmap: Arc::new(RwLock::new(Some(mmap))),
960            num_vectors: AtomicUsize::new(num_vectors),
961            num_partitions: AtomicUsize::new(num_partitions),
962            posting_index_offset: AtomicUsize::new(header.posting_index_offset as usize),
963            posting_data_offset: AtomicUsize::new(header.posting_data_offset as usize),
964        };
965
966        info!(
967            "Loaded SPANN index: {} vectors, {} partitions in {:?}",
968            num_vectors,
969            num_partitions,
970            start.elapsed()
971        );
972
973        Ok(index)
974    }
975
976    /// Insert a single vector into the index
977    pub fn insert(&mut self, vector_id: u32, vector: &[f32]) -> Result<()> {
978        let centroids = self.centroids.read();
979        if centroids.is_empty() {
980            return Err(anyhow!("Cannot insert into empty index - build first"));
981        }
982
983        let partition_id = self.find_nearest_centroid(vector, &centroids);
984        drop(centroids);
985
986        let pq_codes = {
987            let quantizer = self.quantizer.read();
988            if let Some(ref pq) = *quantizer {
989                Some(pq.encode(vector)?)
990            } else {
991                None
992            }
993        };
994
995        let mut partitions = self.partitions.write();
996        if partition_id < partitions.len() {
997            partitions[partition_id].entries.push(PostingEntry {
998                vector_id,
999                pq_codes,
1000            });
1001            self.num_vectors.fetch_add(1, Ordering::Release);
1002        }
1003
1004        Ok(())
1005    }
1006
1007    /// Number of vectors in the index
1008    pub fn len(&self) -> usize {
1009        self.num_vectors.load(Ordering::Acquire)
1010    }
1011
1012    /// Check if index is empty
1013    pub fn is_empty(&self) -> bool {
1014        self.len() == 0
1015    }
1016
1017    /// Number of partitions
1018    pub fn num_partitions(&self) -> usize {
1019        self.num_partitions.load(Ordering::Acquire)
1020    }
1021
1022    /// Verify index file integrity
1023    pub fn verify_index_file(path: &Path) -> Result<bool> {
1024        let file = File::open(path)?;
1025        let mmap = unsafe { Mmap::map(&file)? };
1026
1027        if mmap.len() < HEADER_SIZE {
1028            return Ok(false);
1029        }
1030
1031        let header = SpannHeader::from_bytes(&mmap[..HEADER_SIZE])?;
1032        let stored_checksum = header.checksum;
1033        let computed_checksum = compute_checksum(&mmap[HEADER_SIZE..]);
1034
1035        Ok(stored_checksum == computed_checksum)
1036    }
1037}
1038
/// Round `offset` up to the next multiple of `alignment`.
///
/// Works for any alignment, not only powers of two. An `alignment` of 0 or 1
/// imposes no constraint, so `offset` is returned unchanged. For power-of-two
/// alignments the result is identical to the usual bitmask formulation.
fn align_to(offset: usize, alignment: usize) -> usize {
    if alignment <= 1 {
        return offset;
    }
    match offset % alignment {
        // Already on a boundary.
        0 => offset,
        // Advance by the distance to the next boundary.
        remainder => offset + (alignment - remainder),
    }
}
1043
/// 64-bit FNV-1a hash of `data`.
///
/// Starts from the FNV offset basis and, for every byte, XORs the byte into
/// the running hash and then multiplies by the FNV prime with wrapping
/// arithmetic. Empty input yields the offset basis itself.
fn compute_checksum(data: &[u8]) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0100_0000_01b3;

    data.iter().fold(FNV_OFFSET_BASIS, |acc, &byte| {
        (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    })
}
1053
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Produce `n` random vectors of length `dim`, each scaled to unit L2 norm
    /// (vectors whose norm is zero are left untouched).
    fn generate_random_vectors(n: usize, dim: usize) -> Vec<Vec<f32>> {
        use rand::Rng;
        let mut rng = rand::thread_rng();
        let mut vectors = Vec::with_capacity(n);
        for _ in 0..n {
            let mut components: Vec<f32> = (0..dim).map(|_| rng.gen::<f32>()).collect();
            // Normalize for cosine distance
            let norm: f32 = components.iter().map(|c| c * c).sum::<f32>().sqrt();
            if norm > 0.0 {
                for c in components.iter_mut() {
                    *c /= norm;
                }
            }
            vectors.push(components);
        }
        vectors
    }

    /// Build a PQ-enabled index and check that a stored vector finds itself.
    #[test]
    fn test_spann_build_and_search() {
        let dataset = generate_random_vectors(1000, 384);

        let mut index = SpannIndex::new(SpannConfig {
            dimension: 384,
            use_pq: true,
            num_probes: 20, // More probes for better recall
            ..Default::default()
        });
        index.build(dataset.clone()).unwrap();

        // Query with the first stored vector and ask for the 10 nearest.
        let hits = index.search(&dataset[0], 10).unwrap();

        assert!(!hits.is_empty(), "Search should return results");
        assert_eq!(hits.len(), 10, "Should return exactly k results");

        // With PQ the self-match may not be exactly first, but it should be
        // within the top 3.
        let rank = hits.iter().position(|(id, _)| *id == 0);
        assert!(
            matches!(rank, Some(r) if r < 3),
            "Query vector should be in top 3 results, found at {:?}, results: {:?}",
            rank,
            hits.iter().take(5).collect::<Vec<_>>()
        );
    }

    /// Round-trip an index through save, verify and load, then search it.
    #[test]
    fn test_spann_save_and_load() {
        let scratch = tempdir().unwrap();
        let index_path = scratch.path().join("test.spann");

        let dataset = generate_random_vectors(500, 384);

        let mut index = SpannIndex::new(SpannConfig {
            dimension: 384,
            use_pq: true,
            ..Default::default()
        });
        index.build(dataset.clone()).unwrap();

        // Save
        index.save_to_file(&index_path).unwrap();
        assert!(index_path.exists());

        // Verify
        assert!(SpannIndex::verify_index_file(&index_path).unwrap());

        // Load
        let reloaded = SpannIndex::load_from_file(&index_path).unwrap();
        assert_eq!(reloaded.len(), 500);
        assert!(reloaded.num_partitions() > 0);

        // Search loaded index
        let hits = reloaded.search(&dataset[0], 10).unwrap();
        assert!(!hits.is_empty());
    }

    /// Default config is expected to pick sqrt(n) partitions for these sizes.
    #[test]
    fn test_partition_count() {
        let config = SpannConfig::default();

        for (num_vectors, expected) in [(100, 10), (10000, 100), (1000000, 1000)] {
            assert_eq!(config.compute_partitions(num_vectors), expected);
        }
    }

    /// Building with `use_pq: false` must leave the quantizer unset.
    #[test]
    fn test_no_pq() {
        let dataset = generate_random_vectors(100, 384);

        let mut index = SpannIndex::new(SpannConfig {
            dimension: 384,
            use_pq: false,
            ..Default::default()
        });
        index.build(dataset).unwrap();

        assert!(index.quantizer.read().is_none());
    }
}