1use anyhow::{anyhow, Result};
63use memmap2::{Mmap, MmapMut};
64use parking_lot::RwLock;
65use rand::seq::SliceRandom;
66use std::collections::BinaryHeap;
67use std::fs::{File, OpenOptions};
68use std::path::Path;
69use std::sync::atomic::{AtomicUsize, Ordering};
70use std::sync::Arc;
71use tracing::info;
72
73use super::pq::{PQConfig, ProductQuantizer, NUM_CENTROIDS};
74use super::vamana::DistanceMetric;
75
/// File-format magic bytes identifying a SPANN index file.
const MAGIC: [u8; 4] = *b"SPAN";
/// On-disk format version; bump whenever the layout changes.
const VERSION: u32 = 1;
/// Fixed size of the serialized header at the start of the file.
const HEADER_SIZE: usize = 128;
/// Section offsets are rounded up to this boundary (cache-line sized).
const ALIGNMENT: usize = 64;
/// Bytes per posting-index slot: u64 list offset + u32 entry count.
const POSTING_INDEX_ENTRY_SIZE: usize = 12;

/// Tunable parameters controlling SPANN index construction and search.
#[derive(Debug, Clone)]
pub struct SpannConfig {
    /// Dimensionality of the indexed vectors.
    pub dimension: usize,
    /// Number of k-means partitions; `None` derives ceil(sqrt(n)) at build time.
    pub num_partitions: Option<usize>,
    /// Whether to train a product quantizer and store PQ codes in posting lists.
    pub use_pq: bool,
    /// Number of nearest partitions probed per query.
    pub num_probes: usize,
    /// Maximum k-means iterations during partitioning.
    pub kmeans_iterations: usize,
    /// Distance function used for both clustering and search.
    pub distance_metric: DistanceMetric,
    /// Soft lower bound on partition size.
    /// NOTE(review): not referenced in this file — presumably consumed by a
    /// rebalancing path elsewhere; confirm before removing.
    pub min_partition_size: usize,
    /// Soft upper bound on partition size (same review note as above).
    pub max_partition_size: usize,
}
102
103impl Default for SpannConfig {
104 fn default() -> Self {
105 Self {
106 dimension: 384,
107 num_partitions: None, use_pq: true,
109 num_probes: 10,
110 kmeans_iterations: 25,
111 distance_metric: DistanceMetric::NormalizedDotProduct,
112 min_partition_size: 100,
113 max_partition_size: 10000,
114 }
115 }
116}
117
118impl SpannConfig {
119 pub fn minilm() -> Self {
121 Self {
122 dimension: 384,
123 ..Default::default()
124 }
125 }
126
127 pub fn clip() -> Self {
129 Self {
130 dimension: 768,
131 ..Default::default()
132 }
133 }
134
135 pub fn compute_partitions(&self, num_vectors: usize) -> usize {
137 self.num_partitions
138 .unwrap_or_else(|| ((num_vectors as f64).sqrt().ceil() as usize).max(1))
139 }
140}
141
/// One vector's record inside a partition's posting list.
#[derive(Debug, Clone)]
pub struct PostingEntry {
    /// Identifier of the indexed vector.
    pub vector_id: u32,
    /// Product-quantized codes for the vector, when PQ is enabled.
    pub pq_codes: Option<Vec<u8>>,
}

impl PostingEntry {
    /// On-disk size of one entry: a 4-byte vector id followed by one code
    /// byte per PQ subvector (zero extra bytes when PQ is disabled).
    pub fn serialized_size(pq_subvectors: usize) -> usize {
        let id_bytes = std::mem::size_of::<u32>();
        id_bytes + pq_subvectors
    }
}
157
/// A cluster of vectors grouped under a single k-means centroid.
#[derive(Debug, Clone)]
pub struct Partition {
    /// Partition identifier (index into the centroid list).
    pub id: u32,
    /// Centroid of the partition in the original vector space.
    pub centroid: Vec<f32>,
    /// Posting list of vectors assigned to this partition.
    pub entries: Vec<PostingEntry>,
}
168
/// In-memory mirror of the on-disk file header.
///
/// Serialization is done manually in `to_bytes`/`from_bytes`, so the struct
/// layout itself never hits disk.
/// NOTE(review): packed size is 70 payload bytes + 60 reserved = 130 bytes,
/// which exceeds HEADER_SIZE (128). Harmless today because `reserved` is
/// never serialized, but confirm the intended reserved width.
#[repr(C, packed)]
#[derive(Debug, Clone, Copy)]
struct SpannHeader {
    magic: [u8; 4],            // file magic, must equal MAGIC
    version: u32,              // format version, must equal VERSION
    num_vectors: u64,          // total vectors in the index
    num_partitions: u32,       // number of k-means partitions
    dimension: u32,            // vector dimensionality
    pq_enabled: u8,            // 1 if PQ codes are stored, else 0
    pq_subvectors: u32,        // PQ subvector count (0 when disabled)
    distance_metric: u8,       // 0=NormalizedDotProduct, 1=Euclidean, 2=Cosine
    checksum: u64,             // FNV-1a over everything after the header
    centroids_offset: u64,     // file offset of the centroid section
    codebook_offset: u64,      // file offset of the PQ codebook section
    posting_index_offset: u64, // file offset of the posting index table
    posting_data_offset: u64,  // file offset of the posting data section
    reserved: [u8; 60],        // future fields; never serialized
}
188
189impl SpannHeader {
    /// Builds a header with counts/flags filled in; section offsets and the
    /// checksum start at zero and are patched in by the writer before the
    /// header is flushed.
    fn new(
        num_vectors: usize,
        num_partitions: usize,
        dimension: usize,
        pq_enabled: bool,
        pq_subvectors: usize,
        distance_metric: DistanceMetric,
    ) -> Self {
        Self {
            magic: MAGIC,
            version: VERSION,
            num_vectors: num_vectors as u64,
            num_partitions: num_partitions as u32,
            dimension: dimension as u32,
            pq_enabled: if pq_enabled { 1 } else { 0 },
            pq_subvectors: pq_subvectors as u32,
            // Encode the metric as a stable on-disk byte tag
            // (decoded by distance_metric_enum).
            distance_metric: match distance_metric {
                DistanceMetric::NormalizedDotProduct => 0,
                DistanceMetric::Euclidean => 1,
                DistanceMetric::Cosine => 2,
            },
            checksum: 0,
            centroids_offset: 0,
            codebook_offset: 0,
            posting_index_offset: 0,
            posting_data_offset: 0,
            reserved: [0u8; 60],
        }
    }
219
    /// Serializes the header into a fixed 128-byte little-endian buffer.
    /// Fields occupy the first 70 bytes in declaration order; the remainder
    /// stays zeroed (reserved space).
    fn to_bytes(&self) -> [u8; HEADER_SIZE] {
        let mut bytes = [0u8; HEADER_SIZE];
        let mut offset = 0;

        bytes[offset..offset + 4].copy_from_slice(&self.magic);
        offset += 4;
        bytes[offset..offset + 4].copy_from_slice(&self.version.to_le_bytes());
        offset += 4;
        bytes[offset..offset + 8].copy_from_slice(&self.num_vectors.to_le_bytes());
        offset += 8;
        bytes[offset..offset + 4].copy_from_slice(&self.num_partitions.to_le_bytes());
        offset += 4;
        bytes[offset..offset + 4].copy_from_slice(&self.dimension.to_le_bytes());
        offset += 4;
        bytes[offset] = self.pq_enabled;
        offset += 1;
        bytes[offset..offset + 4].copy_from_slice(&self.pq_subvectors.to_le_bytes());
        offset += 4;
        bytes[offset] = self.distance_metric;
        offset += 1;
        bytes[offset..offset + 8].copy_from_slice(&self.checksum.to_le_bytes());
        offset += 8;
        bytes[offset..offset + 8].copy_from_slice(&self.centroids_offset.to_le_bytes());
        offset += 8;
        bytes[offset..offset + 8].copy_from_slice(&self.codebook_offset.to_le_bytes());
        offset += 8;
        bytes[offset..offset + 8].copy_from_slice(&self.posting_index_offset.to_le_bytes());
        offset += 8;
        bytes[offset..offset + 8].copy_from_slice(&self.posting_data_offset.to_le_bytes());
        bytes
    }
252
    /// Parses a header from the first HEADER_SIZE bytes of an index file,
    /// validating magic and version. The checksum is read but NOT verified
    /// here — callers (load_from_file / verify_index_file) do that against
    /// the file body.
    fn from_bytes(bytes: &[u8]) -> Result<Self> {
        if bytes.len() < HEADER_SIZE {
            return Err(anyhow!("Header too small"));
        }

        let magic: [u8; 4] = bytes[0..4].try_into()?;
        if magic != MAGIC {
            return Err(anyhow!("Invalid magic bytes: {:?}", magic));
        }

        let version = u32::from_le_bytes(bytes[4..8].try_into()?);
        if version != VERSION {
            return Err(anyhow!("Unsupported version: {}", version));
        }

        // Field order and widths mirror to_bytes exactly.
        let mut offset = 8;
        let num_vectors = u64::from_le_bytes(bytes[offset..offset + 8].try_into()?);
        offset += 8;
        let num_partitions = u32::from_le_bytes(bytes[offset..offset + 4].try_into()?);
        offset += 4;
        let dimension = u32::from_le_bytes(bytes[offset..offset + 4].try_into()?);
        offset += 4;
        let pq_enabled = bytes[offset];
        offset += 1;
        let pq_subvectors = u32::from_le_bytes(bytes[offset..offset + 4].try_into()?);
        offset += 4;
        let distance_metric = bytes[offset];
        offset += 1;
        let checksum = u64::from_le_bytes(bytes[offset..offset + 8].try_into()?);
        offset += 8;
        let centroids_offset = u64::from_le_bytes(bytes[offset..offset + 8].try_into()?);
        offset += 8;
        let codebook_offset = u64::from_le_bytes(bytes[offset..offset + 8].try_into()?);
        offset += 8;
        let posting_index_offset = u64::from_le_bytes(bytes[offset..offset + 8].try_into()?);
        offset += 8;
        let posting_data_offset = u64::from_le_bytes(bytes[offset..offset + 8].try_into()?);

        Ok(Self {
            magic,
            version,
            num_vectors,
            num_partitions,
            dimension,
            pq_enabled,
            pq_subvectors,
            distance_metric,
            checksum,
            centroids_offset,
            codebook_offset,
            posting_index_offset,
            posting_data_offset,
            // Reserved bytes on disk are ignored and re-zeroed in memory.
            reserved: [0u8; 60],
        })
    }
308
    /// Decodes the on-disk metric tag back to the enum.
    /// Unknown tags silently fall back to NormalizedDotProduct —
    /// NOTE(review): consider returning an error instead, since an unknown
    /// tag likely indicates file corruption or a newer format.
    fn distance_metric_enum(&self) -> DistanceMetric {
        match self.distance_metric {
            0 => DistanceMetric::NormalizedDotProduct,
            1 => DistanceMetric::Euclidean,
            2 => DistanceMetric::Cosine,
            _ => DistanceMetric::NormalizedDotProduct,
        }
    }
317}
318
/// Disk-backed SPANN-style approximate-nearest-neighbor index:
/// k-means partitions with PQ-compressed posting lists, searched via
/// multi-probe over the nearest centroids.
pub struct SpannIndex {
    /// Build/search configuration.
    pub config: SpannConfig,
    /// Partition centroids (always resident in memory).
    centroids: Arc<RwLock<Vec<Vec<f32>>>>,
    /// Trained product quantizer, when PQ is enabled.
    quantizer: Arc<RwLock<Option<ProductQuantizer>>>,
    /// In-memory posting lists; left empty when the index was loaded from a
    /// file, in which case search reads posting lists through `mmap`.
    partitions: Arc<RwLock<Vec<Partition>>>,
    /// Memory-mapped index file backing mmap-based search, if loaded.
    mmap: Arc<RwLock<Option<Mmap>>>,
    /// Total number of indexed vectors.
    num_vectors: AtomicUsize,
    /// Number of k-means partitions.
    num_partitions: AtomicUsize,
    /// Byte offset of the posting index table inside the mapped file.
    posting_index_offset: AtomicUsize,
    /// Byte offset of the posting data section inside the mapped file.
    posting_data_offset: AtomicUsize,
}
340
341impl SpannIndex {
342 pub fn new(config: SpannConfig) -> Self {
344 Self {
345 config,
346 centroids: Arc::new(RwLock::new(Vec::new())),
347 quantizer: Arc::new(RwLock::new(None)),
348 partitions: Arc::new(RwLock::new(Vec::new())),
349 mmap: Arc::new(RwLock::new(None)),
350 num_vectors: AtomicUsize::new(0),
351 num_partitions: AtomicUsize::new(0),
352 posting_index_offset: AtomicUsize::new(0),
353 posting_data_offset: AtomicUsize::new(0),
354 }
355 }
356
    /// Builds the index from scratch: k-means partitions the vectors,
    /// optionally trains a product quantizer, and fills each partition's
    /// posting list with PQ codes.
    ///
    /// # Errors
    /// Fails on empty input, on a dimension mismatch with the config, or if
    /// PQ training/encoding fails.
    pub fn build(&mut self, vectors: Vec<Vec<f32>>) -> Result<()> {
        if vectors.is_empty() {
            return Err(anyhow!("Cannot build index from empty vectors"));
        }

        let n = vectors.len();
        let dim = vectors[0].len();

        // Only the first vector's length is checked; a ragged input would
        // surface later inside distance computations.
        if dim != self.config.dimension {
            return Err(anyhow!(
                "Vector dimension {} doesn't match config {}",
                dim,
                self.config.dimension
            ));
        }

        let num_partitions = self.config.compute_partitions(n);
        info!(
            "Building SPANN index: {} vectors, {} partitions, PQ={}",
            n, num_partitions, self.config.use_pq
        );

        let start = std::time::Instant::now();

        // Cluster the dataset to obtain partition centroids.
        let centroids = self.kmeans_cluster(&vectors, num_partitions)?;

        // Train the product quantizer on the full dataset, if enabled.
        let quantizer = if self.config.use_pq {
            info!("Training PQ quantizer...");
            let pq_config = PQConfig::for_dimension(dim);
            let pq = ProductQuantizer::train(pq_config, &vectors)?;
            Some(pq)
        } else {
            None
        };

        // One (initially empty) posting list per centroid.
        let mut partitions: Vec<Partition> = centroids
            .iter()
            .enumerate()
            .map(|(i, c)| Partition {
                id: i as u32,
                centroid: c.clone(),
                entries: Vec::new(),
            })
            .collect();

        // Assign every vector to its nearest centroid; posting entries hold
        // PQ codes, not the raw vectors.
        for (vec_id, vector) in vectors.iter().enumerate() {
            let partition_id = self.find_nearest_centroid(vector, &centroids);

            let pq_codes = if let Some(ref pq) = quantizer {
                Some(pq.encode(vector)?)
            } else {
                None
            };

            partitions[partition_id].entries.push(PostingEntry {
                vector_id: vec_id as u32,
                pq_codes,
            });
        }

        // Log the partition-size distribution for diagnostics.
        let sizes: Vec<usize> = partitions.iter().map(|p| p.entries.len()).collect();
        let min_size = sizes.iter().min().copied().unwrap_or(0);
        let max_size = sizes.iter().max().copied().unwrap_or(0);
        let avg_size = if !sizes.is_empty() {
            sizes.iter().sum::<usize>() / sizes.len()
        } else {
            0
        };
        info!(
            "Partition distribution: min={}, max={}, avg={}",
            min_size, max_size, avg_size
        );

        // Publish the built state (locks taken one at a time; readers may
        // observe intermediate states briefly).
        *self.centroids.write() = centroids;
        *self.quantizer.write() = quantizer;
        *self.partitions.write() = partitions;
        self.num_vectors.store(n, Ordering::Release);
        self.num_partitions.store(num_partitions, Ordering::Release);

        info!("SPANN build complete in {:?}", start.elapsed());

        Ok(())
    }
450
    /// Lloyd's k-means over `vectors`, returning `k` centroids.
    ///
    /// Initialization samples vectors at random; if `k > n` the sample is
    /// padded by cycling the shuffled indices (duplicate centroids are then
    /// possible). Iterates until assignments stabilize or
    /// `kmeans_iterations` is reached.
    fn kmeans_cluster(&self, vectors: &[Vec<f32>], k: usize) -> Result<Vec<Vec<f32>>> {
        let n = vectors.len();
        // Assumes non-empty input; build() checks this before calling.
        let dim = vectors[0].len();
        let iterations = self.config.kmeans_iterations;

        // Random initialization from the data points themselves.
        let mut rng = rand::thread_rng();
        let mut indices: Vec<usize> = (0..n).collect();
        indices.shuffle(&mut rng);

        let mut centroids: Vec<Vec<f32>> = indices
            .iter()
            .take(k)
            .map(|&i| vectors[i].clone())
            .collect();

        // Pad when there are fewer vectors than requested centroids.
        while centroids.len() < k {
            let idx = indices[centroids.len() % n];
            centroids.push(vectors[idx].clone());
        }

        let mut assignments = vec![0usize; n];

        for iter in 0..iterations {
            let mut changed = 0usize;

            // Assignment step: nearest centroid for each vector.
            for (i, vec) in vectors.iter().enumerate() {
                let new_assignment = self.find_nearest_centroid(vec, &centroids);
                if new_assignment != assignments[i] {
                    changed += 1;
                }
                assignments[i] = new_assignment;
            }

            // Update step: recompute each cluster's mean.
            let mut new_centroids: Vec<Vec<f32>> = vec![vec![0.0; dim]; k];
            let mut counts = vec![0usize; k];

            for (i, vec) in vectors.iter().enumerate() {
                let c = assignments[i];
                counts[c] += 1;
                for (j, &v) in vec.iter().enumerate() {
                    new_centroids[c][j] += v;
                }
            }

            for c in 0..k {
                // Empty clusters keep their previous centroid.
                if counts[c] > 0 {
                    for j in 0..dim {
                        new_centroids[c][j] /= counts[c] as f32;
                    }
                    centroids[c] = new_centroids[c].clone();
                }
            }

            // Progress log every 5 iterations (iterations 1, 6, 11, ...).
            if iter % 5 == 0 {
                info!(
                    "K-means iter {}/{}: {} assignments changed",
                    iter + 1,
                    iterations,
                    changed
                );
            }

            // Converged: no vector changed cluster this iteration.
            if changed == 0 {
                info!("K-means converged at iteration {}", iter + 1);
                break;
            }
        }

        Ok(centroids)
    }
528
529 #[inline]
531 fn find_nearest_centroid(&self, vector: &[f32], centroids: &[Vec<f32>]) -> usize {
532 let mut best_idx = 0;
533 let mut best_dist = f32::MAX;
534
535 for (i, centroid) in centroids.iter().enumerate() {
536 let dist = self.compute_distance(vector, centroid);
537 if dist < best_dist {
538 best_dist = dist;
539 best_idx = i;
540 }
541 }
542
543 best_idx
544 }
545
546 #[inline]
548 fn compute_distance(&self, a: &[f32], b: &[f32]) -> f32 {
549 match self.config.distance_metric {
550 DistanceMetric::Euclidean => a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum(),
551 DistanceMetric::NormalizedDotProduct | DistanceMetric::Cosine => {
552 let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
554 1.0 - dot
555 }
556 }
557 }
558
    /// Searches for the `k` approximate nearest neighbors of `query`.
    ///
    /// Probes the `num_probes` partitions whose centroids are closest to the
    /// query and scores candidates with PQ asymmetric distances. Posting
    /// lists come from memory when present (freshly built index) or from the
    /// mmap'd file (loaded index). Returns `(vector_id, distance)` pairs
    /// sorted by ascending distance; empty index yields an empty result.
    ///
    /// # Errors
    /// Fails when PQ is disabled (posting entries store only PQ codes, so
    /// there is nothing to score against) or on a corrupt posting list.
    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(u32, f32)>> {
        let centroids = self.centroids.read();
        if centroids.is_empty() {
            return Ok(Vec::new());
        }

        // Rank every partition by centroid distance; probe the closest few.
        let mut partition_distances: Vec<(usize, f32)> = centroids
            .iter()
            .enumerate()
            .map(|(i, c)| (i, self.compute_distance(query, c)))
            .collect();

        partition_distances.sort_by(|a, b| a.1.total_cmp(&b.1));
        let probe_partitions: Vec<usize> = partition_distances
            .iter()
            .take(self.config.num_probes)
            .map(|(i, _)| *i)
            .collect();

        let quantizer = self.quantizer.read();
        let partitions = self.partitions.read();

        // Build the query's PQ distance table once; it is reused for every
        // candidate entry across all probed partitions.
        let distance_table = if let Some(ref pq) = *quantizer {
            Some(pq.build_distance_table(query)?)
        } else {
            anyhow::bail!(
                "SPANN search requires PQ quantizer but use_pq is disabled. \
                 PostingEntry stores only PQ codes, not original vectors."
            );
        };

        // Max-heap of (distance, id): keeps the k smallest distances by
        // evicting the current maximum whenever the heap grows past k.
        let mut heap: BinaryHeap<(ordered_float::OrderedFloat<f32>, u32)> =
            BinaryHeap::with_capacity(k);

        if !partitions.is_empty() {
            // In-memory path: posting lists live in `partitions`.
            for &partition_id in &probe_partitions {
                if partition_id >= partitions.len() {
                    continue;
                }

                for entry in &partitions[partition_id].entries {
                    let dist = if let (Some(ref table), Some(ref codes)) =
                        (&distance_table, &entry.pq_codes)
                    {
                        quantizer
                            .as_ref()
                            .unwrap()
                            .distance_with_table(table, codes)
                    } else {
                        // Entries without PQ codes cannot be scored.
                        continue;
                    };

                    heap.push((ordered_float::OrderedFloat(dist), entry.vector_id));
                    if heap.len() > k {
                        heap.pop();
                    }
                }
            }
        } else {
            // Mmap path: index was loaded from disk; decode posting lists
            // on the fly from the mapped file.
            let mmap_guard = self.mmap.read();
            if let Some(ref mmap) = *mmap_guard {
                // NOTE(review): assumes subvector count == dimension / 8,
                // mirroring save_to_file's hard-coded layout — confirm
                // against PQConfig::for_dimension.
                let pq_subvectors = if quantizer.is_some() {
                    self.config.dimension / 8
                } else {
                    0
                };

                for &partition_id in &probe_partitions {
                    let entries = self.read_posting_list(mmap, partition_id, pq_subvectors)?;

                    for entry in entries {
                        let dist = if let (Some(ref table), Some(ref codes)) =
                            (&distance_table, &entry.pq_codes)
                        {
                            quantizer
                                .as_ref()
                                .unwrap()
                                .distance_with_table(table, codes)
                        } else {
                            continue;
                        };

                        heap.push((ordered_float::OrderedFloat(dist), entry.vector_id));
                        if heap.len() > k {
                            heap.pop();
                        }
                    }
                }
            }
        }

        // Drain the heap and sort ascending by distance for the caller.
        let mut results: Vec<(u32, f32)> = heap.into_iter().map(|(d, id)| (id, d.0)).collect();
        results.sort_by(|a, b| a.1.total_cmp(&b.1));

        Ok(results)
    }
665
    /// Decodes one partition's posting list from the mapped file.
    ///
    /// Looks up the partition's `(offset, count)` slot in the posting index
    /// table, then parses `count` fixed-size entries from the posting data
    /// section. All reads are bounds-checked against the mapping before
    /// slicing, so a corrupt index errors rather than panics.
    fn read_posting_list(
        &self,
        mmap: &Mmap,
        partition_id: usize,
        pq_subvectors: usize,
    ) -> Result<Vec<PostingEntry>> {
        let index_offset = self.posting_index_offset.load(Ordering::Acquire);
        let data_offset = self.posting_data_offset.load(Ordering::Acquire);

        // Locate this partition's slot in the posting index table.
        let entry_offset = index_offset + partition_id * POSTING_INDEX_ENTRY_SIZE;
        if entry_offset + POSTING_INDEX_ENTRY_SIZE > mmap.len() {
            return Err(anyhow!("Posting index out of bounds"));
        }

        // Slot layout: u64 offset (relative to the data section), u32 count.
        let list_offset =
            u64::from_le_bytes(mmap[entry_offset..entry_offset + 8].try_into()?) as usize;
        let count =
            u32::from_le_bytes(mmap[entry_offset + 8..entry_offset + 12].try_into()?) as usize;

        let entry_size = PostingEntry::serialized_size(pq_subvectors);
        let list_start = data_offset + list_offset;
        let list_end = list_start + count * entry_size;

        if list_end > mmap.len() {
            return Err(anyhow!("Posting list data out of bounds"));
        }

        let mut entries = Vec::with_capacity(count);
        let mut offset = list_start;

        for _ in 0..count {
            // Entry layout: u32 vector id, then pq_subvectors code bytes.
            let vector_id = u32::from_le_bytes(mmap[offset..offset + 4].try_into()?);
            offset += 4;

            let pq_codes = if pq_subvectors > 0 {
                let codes = mmap[offset..offset + pq_subvectors].to_vec();
                offset += pq_subvectors;
                Some(codes)
            } else {
                None
            };

            entries.push(PostingEntry {
                vector_id,
                pq_codes,
            });
        }

        Ok(entries)
    }
719
    /// Serializes the index to `path` as a single mmap-friendly file:
    /// header | centroids | PQ codebook | posting index | posting data,
    /// with every section start aligned to `ALIGNMENT` and an FNV-1a
    /// checksum covering everything after the header.
    ///
    /// # Errors
    /// Fails when the index is empty or on any I/O error.
    pub fn save_to_file(&self, path: &Path) -> Result<()> {
        let start = std::time::Instant::now();

        let centroids = self.centroids.read();
        let quantizer = self.quantizer.read();
        let partitions = self.partitions.read();

        if centroids.is_empty() {
            return Err(anyhow!("Cannot save empty index"));
        }

        let num_vectors = self.num_vectors.load(Ordering::Acquire);
        let num_partitions = centroids.len();
        let dimension = self.config.dimension;
        let pq_enabled = quantizer.is_some();
        // NOTE(review): subvector count is assumed to be dimension / 8 with
        // subvec_dim 8 throughout this file — confirm PQConfig always
        // produces that exact shape.
        let pq_subvectors = if pq_enabled { dimension / 8 } else { 0 };

        // Compute aligned section offsets up front so the file can be
        // pre-sized and written through a single mutable mapping.
        let centroids_offset = align_to(HEADER_SIZE, ALIGNMENT);
        let centroids_size = num_partitions * dimension * 4;

        let codebook_offset = align_to(centroids_offset + centroids_size, ALIGNMENT);
        let codebook_size = if pq_enabled {
            // Three u32 dims (subvectors, centroids, subvec_dim) followed by
            // f32 codebook: subvectors x NUM_CENTROIDS x 8 components.
            4 + 4 + 4 + (pq_subvectors * NUM_CENTROIDS * 8 * 4)
        } else {
            0
        };

        let posting_index_offset = align_to(codebook_offset + codebook_size, ALIGNMENT);
        let posting_index_size = num_partitions * POSTING_INDEX_ENTRY_SIZE;

        let posting_data_offset = align_to(posting_index_offset + posting_index_size, ALIGNMENT);

        let entry_size = PostingEntry::serialized_size(pq_subvectors);
        let posting_data_size: usize = partitions
            .iter()
            .map(|p| p.entries.len() * entry_size)
            .sum();

        let total_size = posting_data_offset + posting_data_size;

        // Create the file at its final size and map it for writing.
        let file = OpenOptions::new()
            .read(true)
            .write(true)
            .create(true)
            .truncate(true)
            .open(path)?;
        file.set_len(total_size as u64)?;

        let mut mmap = unsafe { MmapMut::map_mut(&file)? };

        // The header is written last (once the checksum is known); record
        // the section offsets now.
        let mut header = SpannHeader::new(
            num_vectors,
            num_partitions,
            dimension,
            pq_enabled,
            pq_subvectors,
            self.config.distance_metric,
        );
        header.centroids_offset = centroids_offset as u64;
        header.codebook_offset = codebook_offset as u64;
        header.posting_index_offset = posting_index_offset as u64;
        header.posting_data_offset = posting_data_offset as u64;

        // Section 1: partition centroids as little-endian f32 components.
        let mut offset = centroids_offset;
        for centroid in centroids.iter() {
            for &val in centroid {
                mmap[offset..offset + 4].copy_from_slice(&val.to_le_bytes());
                offset += 4;
            }
        }

        // Section 2: PQ codebook (dims header + all subspace centroids).
        if let Some(ref pq) = *quantizer {
            offset = codebook_offset;
            mmap[offset..offset + 4].copy_from_slice(&(pq_subvectors as u32).to_le_bytes());
            offset += 4;
            mmap[offset..offset + 4].copy_from_slice(&(NUM_CENTROIDS as u32).to_le_bytes());
            offset += 4;
            // Hard-coded subvector dimension of 8 (see review note above).
            mmap[offset..offset + 4].copy_from_slice(&8u32.to_le_bytes());
            offset += 4;

            for subspace_centroids in &pq.centroids {
                for centroid in subspace_centroids {
                    for &val in centroid {
                        mmap[offset..offset + 4].copy_from_slice(&val.to_le_bytes());
                        offset += 4;
                    }
                }
            }
        }

        // Sections 3 + 4: posting index table and packed posting lists.
        let mut data_write_offset: usize = 0;
        for (partition_id, partition) in partitions.iter().enumerate() {
            // Index slot: u64 data-relative offset + u32 entry count.
            let index_entry_offset = posting_index_offset + partition_id * POSTING_INDEX_ENTRY_SIZE;
            mmap[index_entry_offset..index_entry_offset + 8]
                .copy_from_slice(&(data_write_offset as u64).to_le_bytes());
            mmap[index_entry_offset + 8..index_entry_offset + 12]
                .copy_from_slice(&(partition.entries.len() as u32).to_le_bytes());

            offset = posting_data_offset + data_write_offset;
            for entry in &partition.entries {
                mmap[offset..offset + 4].copy_from_slice(&entry.vector_id.to_le_bytes());
                offset += 4;

                if let Some(ref codes) = entry.pq_codes {
                    mmap[offset..offset + codes.len()].copy_from_slice(codes);
                    offset += codes.len();
                }
            }

            data_write_offset += partition.entries.len() * entry_size;
        }

        // Checksum everything after the header, then write the header.
        let checksum = compute_checksum(&mmap[HEADER_SIZE..]);
        header.checksum = checksum;
        mmap[..HEADER_SIZE].copy_from_slice(&header.to_bytes());

        mmap.flush()?;

        info!(
            "Saved SPANN index: {} vectors, {} partitions, {} bytes in {:?}",
            num_vectors,
            num_partitions,
            total_size,
            start.elapsed()
        );

        Ok(())
    }
860
    /// Loads an index previously written by `save_to_file`, verifying
    /// magic, version, and checksum before decoding.
    ///
    /// Centroids and the PQ codebook are materialized in memory; posting
    /// lists stay on disk and are served through the retained mmap —
    /// `partitions` is deliberately left empty, which routes `search` to
    /// the mmap path.
    pub fn load_from_file(path: &Path) -> Result<Self> {
        let start = std::time::Instant::now();

        if !path.exists() {
            return Err(anyhow!("Index file not found: {:?}", path));
        }

        let file = File::open(path)?;
        let mmap = unsafe { Mmap::map(&file)? };

        let header = SpannHeader::from_bytes(&mmap[..HEADER_SIZE])?;

        // Reject files whose body does not match the stored checksum.
        let stored_checksum = header.checksum;
        let computed_checksum = compute_checksum(&mmap[HEADER_SIZE..]);
        if stored_checksum != computed_checksum {
            return Err(anyhow!(
                "Checksum mismatch: stored={}, computed={}",
                stored_checksum,
                computed_checksum
            ));
        }

        let num_vectors = header.num_vectors as usize;
        let num_partitions = header.num_partitions as usize;
        let dimension = header.dimension as usize;
        let pq_enabled = header.pq_enabled == 1;
        // Stored subvector count is currently unused; the codebook section
        // carries its own dims header below.
        let _pq_subvectors = header.pq_subvectors as usize;

        // Decode the centroid section (num_partitions x dimension f32s).
        let mut centroids = Vec::with_capacity(num_partitions);
        let mut offset = header.centroids_offset as usize;
        for _ in 0..num_partitions {
            let mut centroid = Vec::with_capacity(dimension);
            for _ in 0..dimension {
                let val = f32::from_le_bytes(mmap[offset..offset + 4].try_into()?);
                centroid.push(val);
                offset += 4;
            }
            centroids.push(centroid);
        }

        // Decode the PQ codebook: dims header, then all subspace centroids
        // in the exact order save_to_file wrote them.
        let quantizer = if pq_enabled {
            offset = header.codebook_offset as usize;
            let num_subvectors = u32::from_le_bytes(mmap[offset..offset + 4].try_into()?) as usize;
            offset += 4;
            let num_centroids = u32::from_le_bytes(mmap[offset..offset + 4].try_into()?) as usize;
            offset += 4;
            let subvec_dim = u32::from_le_bytes(mmap[offset..offset + 4].try_into()?) as usize;
            offset += 4;

            let mut pq_centroids = Vec::with_capacity(num_subvectors);
            for _ in 0..num_subvectors {
                let mut subspace = Vec::with_capacity(num_centroids);
                for _ in 0..num_centroids {
                    let mut centroid = Vec::with_capacity(subvec_dim);
                    for _ in 0..subvec_dim {
                        let val = f32::from_le_bytes(mmap[offset..offset + 4].try_into()?);
                        centroid.push(val);
                        offset += 4;
                    }
                    subspace.push(centroid);
                }
                pq_centroids.push(subspace);
            }

            let config = PQConfig {
                dimension,
                num_subvectors,
                subvec_dim,
                num_centroids,
                // Irrelevant for an already-trained quantizer; placeholder.
                kmeans_iterations: 20,
            };

            Some(ProductQuantizer {
                config,
                centroids: pq_centroids,
                trained: true,
            })
        } else {
            None
        };

        // Reconstruct a config matching the stored index parameters.
        let config = SpannConfig {
            dimension,
            num_partitions: Some(num_partitions),
            use_pq: pq_enabled,
            distance_metric: header.distance_metric_enum(),
            ..Default::default()
        };

        let index = SpannIndex {
            config,
            centroids: Arc::new(RwLock::new(centroids)),
            quantizer: Arc::new(RwLock::new(quantizer)),
            // Posting lists intentionally stay mmap-resident (see doc).
            partitions: Arc::new(RwLock::new(Vec::new())),
            mmap: Arc::new(RwLock::new(Some(mmap))),
            num_vectors: AtomicUsize::new(num_vectors),
            num_partitions: AtomicUsize::new(num_partitions),
            posting_index_offset: AtomicUsize::new(header.posting_index_offset as usize),
            posting_data_offset: AtomicUsize::new(header.posting_data_offset as usize),
        };

        info!(
            "Loaded SPANN index: {} vectors, {} partitions in {:?}",
            num_vectors,
            num_partitions,
            start.elapsed()
        );

        Ok(index)
    }
975
976 pub fn insert(&mut self, vector_id: u32, vector: &[f32]) -> Result<()> {
978 let centroids = self.centroids.read();
979 if centroids.is_empty() {
980 return Err(anyhow!("Cannot insert into empty index - build first"));
981 }
982
983 let partition_id = self.find_nearest_centroid(vector, ¢roids);
984 drop(centroids);
985
986 let pq_codes = {
987 let quantizer = self.quantizer.read();
988 if let Some(ref pq) = *quantizer {
989 Some(pq.encode(vector)?)
990 } else {
991 None
992 }
993 };
994
995 let mut partitions = self.partitions.write();
996 if partition_id < partitions.len() {
997 partitions[partition_id].entries.push(PostingEntry {
998 vector_id,
999 pq_codes,
1000 });
1001 self.num_vectors.fetch_add(1, Ordering::Release);
1002 }
1003
1004 Ok(())
1005 }
1006
1007 pub fn len(&self) -> usize {
1009 self.num_vectors.load(Ordering::Acquire)
1010 }
1011
1012 pub fn is_empty(&self) -> bool {
1014 self.len() == 0
1015 }
1016
1017 pub fn num_partitions(&self) -> usize {
1019 self.num_partitions.load(Ordering::Acquire)
1020 }
1021
1022 pub fn verify_index_file(path: &Path) -> Result<bool> {
1024 let file = File::open(path)?;
1025 let mmap = unsafe { Mmap::map(&file)? };
1026
1027 if mmap.len() < HEADER_SIZE {
1028 return Ok(false);
1029 }
1030
1031 let header = SpannHeader::from_bytes(&mmap[..HEADER_SIZE])?;
1032 let stored_checksum = header.checksum;
1033 let computed_checksum = compute_checksum(&mmap[HEADER_SIZE..]);
1034
1035 Ok(stored_checksum == computed_checksum)
1036 }
1037}
1038
/// Rounds `offset` up to the next multiple of `alignment`.
/// Uses the classic mask trick, which requires `alignment` to be a power
/// of two (true for the only caller, which passes ALIGNMENT = 64).
fn align_to(offset: usize, alignment: usize) -> usize {
    let mask = alignment - 1;
    (offset + mask) & !mask
}
1043
/// 64-bit FNV-1a hash over `data`, used as the index-file checksum.
fn compute_checksum(data: &[u8]) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xcbf29ce484222325;
    const FNV_PRIME: u64 = 0x100000001b3;
    data.iter().fold(FNV_OFFSET_BASIS, |acc, &byte| {
        (acc ^ byte as u64).wrapping_mul(FNV_PRIME)
    })
}
1053
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Produces `n` random unit-norm vectors of dimension `dim`.
    fn generate_random_vectors(n: usize, dim: usize) -> Vec<Vec<f32>> {
        use rand::Rng;
        let mut rng = rand::thread_rng();
        (0..n)
            .map(|_| {
                let mut vec: Vec<f32> = (0..dim).map(|_| rng.gen::<f32>()).collect();
                // L2-normalize so the dot-product metric behaves as a
                // proper similarity (compute_distance does not normalize).
                let norm: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
                if norm > 0.0 {
                    vec.iter_mut().for_each(|x| *x /= norm);
                }
                vec
            })
            .collect()
    }

    /// End-to-end build + search: searching with an indexed vector should
    /// surface that vector near the top despite PQ approximation error.
    #[test]
    fn test_spann_build_and_search() {
        let vectors = generate_random_vectors(1000, 384);

        let config = SpannConfig {
            dimension: 384,
            use_pq: true,
            num_probes: 20,
            ..Default::default()
        };

        let mut index = SpannIndex::new(config);
        index.build(vectors.clone()).unwrap();

        let results = index.search(&vectors[0], 10).unwrap();

        assert!(!results.is_empty(), "Search should return results");
        assert_eq!(results.len(), 10, "Should return exactly k results");

        // PQ is lossy, so allow the exact query to land anywhere in the
        // top 3 rather than demanding rank 0.
        let query_position = results.iter().position(|(id, _)| *id == 0);
        assert!(
            query_position.is_some() && query_position.unwrap() < 3,
            "Query vector should be in top 3 results, found at {:?}, results: {:?}",
            query_position,
            results.iter().take(5).collect::<Vec<_>>()
        );
    }

    /// Round-trips an index through disk, verifies the file, and searches
    /// the loaded (mmap-backed) copy.
    #[test]
    fn test_spann_save_and_load() {
        let temp_dir = tempdir().unwrap();
        let index_path = temp_dir.path().join("test.spann");

        let vectors = generate_random_vectors(500, 384);

        let config = SpannConfig {
            dimension: 384,
            use_pq: true,
            ..Default::default()
        };

        let mut index = SpannIndex::new(config);
        index.build(vectors.clone()).unwrap();

        index.save_to_file(&index_path).unwrap();
        assert!(index_path.exists());

        assert!(SpannIndex::verify_index_file(&index_path).unwrap());

        let loaded = SpannIndex::load_from_file(&index_path).unwrap();
        assert_eq!(loaded.len(), 500);
        assert!(loaded.num_partitions() > 0);

        let results = loaded.search(&vectors[0], 10).unwrap();
        assert!(!results.is_empty());
    }

    /// The default partition count follows the ceil(sqrt(n)) heuristic.
    #[test]
    fn test_partition_count() {
        let config = SpannConfig::default();

        assert_eq!(config.compute_partitions(100), 10);
        assert_eq!(config.compute_partitions(10000), 100);
        assert_eq!(config.compute_partitions(1000000), 1000);
    }

    /// Building with PQ disabled must not train a quantizer.
    #[test]
    fn test_no_pq() {
        let vectors = generate_random_vectors(100, 384);

        let config = SpannConfig {
            dimension: 384,
            use_pq: false,
            ..Default::default()
        };

        let mut index = SpannIndex::new(config);
        index.build(vectors).unwrap();

        assert!(index.quantizer.read().is_none());
    }
}