1use std::cmp::Ordering;
29use std::collections::BinaryHeap;
30
31use crate::error::{RabitqError, Result};
32use crate::quantize::BinaryCode;
33use crate::rotation::{normalize_inplace, RandomRotation, RandomRotationKind};
34
/// A single hit produced by [`AnnIndex::search`].
///
/// Every index in this module keeps the entries with the *smallest* `score`,
/// so lower means closer. Whether the score is exact or estimated depends on
/// the index that produced it.
#[derive(Debug, Clone, PartialEq)]
pub struct SearchResult {
    /// Identifier of the matched vector, as stored by the index.
    pub id: usize,
    /// Distance-like score to the query (smaller is better).
    pub score: f32,
}
42
43pub trait AnnIndex: Send + Sync {
45 fn add(&mut self, id: usize, vector: Vec<f32>) -> Result<()>;
46 fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>>;
47 fn len(&self) -> usize;
48 fn is_empty(&self) -> bool {
49 self.len() == 0
50 }
51 fn dim(&self) -> usize;
52 fn memory_bytes(&self) -> usize;
55}
56
/// Ascending comparator for scores.
///
/// Uses IEEE-754 `totalOrder` (`f32::total_cmp`), so it never fails on NaN:
/// `-0.0 < +0.0`, and positive NaN sorts after `+inf`.
#[inline]
fn cmp_score_asc(a: f32, b: f32) -> Ordering {
    f32::total_cmp(&a, &b)
}
65
66struct TopK {
71 k: usize,
72 heap: BinaryHeap<HeapEntry>,
75}
76
77#[derive(Debug, Clone, Copy)]
80struct HeapEntry {
81 id: usize,
82 score: f32,
83 pos: u32,
84}
85impl PartialEq for HeapEntry {
86 fn eq(&self, o: &Self) -> bool {
87 self.score == o.score && self.id == o.id
88 }
89}
90impl Eq for HeapEntry {}
91impl Ord for HeapEntry {
92 fn cmp(&self, o: &Self) -> Ordering {
93 self.score.total_cmp(&o.score).then(self.id.cmp(&o.id))
94 }
95}
96impl PartialOrd for HeapEntry {
97 fn partial_cmp(&self, o: &Self) -> Option<Ordering> {
98 Some(self.cmp(o))
99 }
100}
101
102impl TopK {
103 fn new(k: usize) -> Self {
104 Self {
105 k: k.max(1),
106 heap: BinaryHeap::with_capacity(k.max(1) + 1),
107 }
108 }
109 #[inline]
110 fn push(&mut self, id: usize, score: f32) {
111 self.push_raw(id, score, 0);
112 }
113 #[inline]
115 fn push_raw(&mut self, id: usize, score: f32, pos: usize) {
116 if self.heap.len() < self.k {
117 self.heap.push(HeapEntry {
118 id,
119 score,
120 pos: pos as u32,
121 });
122 return;
123 }
124 let worst = self.heap.peek().unwrap().score;
125 if score.total_cmp(&worst) == Ordering::Less {
126 self.heap.pop();
127 self.heap.push(HeapEntry {
128 id,
129 score,
130 pos: pos as u32,
131 });
132 }
133 }
134 fn into_sorted_asc(self) -> Vec<SearchResult> {
135 let mut v: Vec<SearchResult> = self
136 .heap
137 .into_iter()
138 .map(|e| SearchResult {
139 id: e.id,
140 score: e.score,
141 })
142 .collect();
143 v.sort_unstable_by(|a, b| cmp_score_asc(a.score, b.score));
144 v
145 }
146 fn into_sorted_with_pos(self) -> Vec<(u32, u32, f32)> {
149 let mut v: Vec<(u32, u32, f32)> = self
150 .heap
151 .into_iter()
152 .map(|e| (e.pos, e.id as u32, e.score))
153 .collect();
154 v.sort_unstable_by(|a, b| cmp_score_asc(a.2, b.2));
155 v
156 }
157}
158
/// Exact brute-force baseline.
///
/// Keeps every vector at full `f32` precision and compares the query against
/// all of them on each search.
pub struct FlatF32Index {
    /// Length every stored vector and every query must have.
    dim: usize,
    /// `(id, vector)` pairs in insertion order.
    vectors: Vec<(usize, Vec<f32>)>,
}

impl FlatF32Index {
    /// Builds an empty index for `dim`-dimensional vectors.
    pub fn new(dim: usize) -> Self {
        let vectors = Vec::new();
        FlatF32Index { vectors, dim }
    }
}
174
/// Squared Euclidean distance between `a` and `b`.
///
/// Pairs elements positionally and stops at the shorter slice; callers are
/// expected to pass equal-length slices.
#[inline]
fn sq_l2(a: &[f32], b: &[f32]) -> f32 {
    let squared_gap = |(x, y): (&f32, &f32)| {
        let gap = x - y;
        gap * gap
    };
    a.iter().zip(b).map(squared_gap).sum::<f32>()
}
182
183impl AnnIndex for FlatF32Index {
184 fn add(&mut self, id: usize, vector: Vec<f32>) -> Result<()> {
185 if vector.len() != self.dim {
186 return Err(RabitqError::DimensionMismatch {
187 expected: self.dim,
188 actual: vector.len(),
189 });
190 }
191 self.vectors.push((id, vector));
192 Ok(())
193 }
194
195 fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
196 if self.vectors.is_empty() {
197 return Err(RabitqError::EmptyIndex);
198 }
199 if query.len() != self.dim {
200 return Err(RabitqError::DimensionMismatch {
201 expected: self.dim,
202 actual: query.len(),
203 });
204 }
205 let k_eff = k.min(self.vectors.len());
206 let mut top = TopK::new(k_eff);
207 for (id, v) in &self.vectors {
208 top.push(*id, sq_l2(query, v));
209 }
210 Ok(top.into_sorted_asc())
211 }
212
213 fn len(&self) -> usize {
214 self.vectors.len()
215 }
216 fn dim(&self) -> usize {
217 self.dim
218 }
219 fn memory_bytes(&self) -> usize {
220 self.vectors.len() * (self.dim * 4 + 16)
223 }
224}
225
226pub struct RabitqIndex {
248 dim: usize,
249 n_words: usize,
250 rotation: RandomRotation,
251 ids: Vec<u32>,
252 norms: Vec<f32>,
253 packed: Vec<u64>,
255 last_word_mask: u64,
257 cos_lut: Vec<f32>,
259}
260
/// Mask of the meaningful bits in the final code word for `dim` dimensions.
///
/// Bits are packed MSB-first, so a partial tail word keeps its `dim % 64`
/// *highest* bits. A full tail word keeps all bits, and `dim == 0` yields 0.
fn build_last_word_mask(dim: usize) -> u64 {
    match (dim, dim % 64) {
        (0, _) => 0,
        (_, 0) => u64::MAX,
        (_, tail_bits) => u64::MAX << (64 - tail_bits),
    }
}
273
/// Lookup table mapping an agreement count `b` (0..=dim) to
/// `cos(pi * (1 - b / dim))`, the cosine estimate used by the symmetric scan.
///
/// Entry 0 is `cos(pi)` and entry `dim` is `cos(0) = 1`. For `dim == 0` the
/// single entry is NaN (0/0).
fn build_cos_lut(dim: usize) -> Vec<f32> {
    let denom = dim as f32;
    let mut table = Vec::with_capacity(dim + 1);
    for agreeing in 0..=dim {
        let fraction = agreeing as f32 / denom;
        table.push((std::f32::consts::PI * (1.0 - fraction)).cos());
    }
    table
}
281
282impl RabitqIndex {
283 pub fn new(dim: usize, seed: u64) -> Self {
286 Self::new_with_rotation(dim, seed, RandomRotationKind::HaarDense)
287 }
288
289 pub fn new_with_rotation(dim: usize, seed: u64, kind: RandomRotationKind) -> Self {
295 let n_words = (dim + 63) / 64;
296 let rotation = match kind {
297 RandomRotationKind::HaarDense => RandomRotation::random(dim, seed),
298 RandomRotationKind::HadamardSigned => RandomRotation::hadamard(dim, seed),
299 };
300 Self {
301 dim,
302 n_words,
303 rotation,
304 ids: Vec::new(),
305 norms: Vec::new(),
306 packed: Vec::new(),
307 last_word_mask: build_last_word_mask(dim),
308 cos_lut: build_cos_lut(dim),
309 }
310 }
311
312 pub fn encode_vector(&self, v: &[f32]) -> BinaryCode {
316 let norm: f32 = v.iter().map(|&x| x * x).sum::<f32>().sqrt();
317 let mut unit = v.to_vec();
318 normalize_inplace(&mut unit);
319 let rotated = self.rotation.apply(&unit);
320 BinaryCode::encode(&rotated, norm)
321 }
322
323 pub fn encode_query_packed(&self, q: &[f32]) -> (Vec<u64>, f32) {
327 use std::cell::RefCell;
334 thread_local! {
335 static SCRATCH: RefCell<(Vec<f32>, Vec<f32>)> =
336 const { RefCell::new((Vec::new(), Vec::new())) };
337 }
338 let norm: f32 = q.iter().map(|&x| x * x).sum::<f32>().sqrt();
339 let dim = q.len();
340 let mut words = vec![0u64; self.n_words];
341 SCRATCH.with(|s| {
342 let mut s = s.borrow_mut();
343 let (unit, rotated) = &mut *s;
344 unit.clear();
345 unit.extend_from_slice(q);
346 normalize_inplace(unit);
347 if rotated.len() != dim {
348 rotated.resize(dim, 0.0);
349 }
350 self.rotation.apply_into(unit, rotated);
351 for (i, &v) in rotated.iter().enumerate() {
352 if v >= 0.0 {
353 words[i / 64] |= 1u64 << (63 - (i % 64));
354 }
355 }
356 });
357 (words, norm.max(1e-10))
358 }
359
360 pub fn encode_query(&self, q: &[f32]) -> BinaryCode {
363 let (words, norm) = self.encode_query_packed(q);
364 BinaryCode {
365 words,
366 norm,
367 dim: self.dim,
368 }
369 }
370
371 pub fn prepare_query_f32(&self, q: &[f32]) -> (Vec<f32>, f32) {
374 let norm: f32 = q.iter().map(|&x| x * x).sum::<f32>().sqrt();
375 let mut unit = q.to_vec();
376 normalize_inplace(&mut unit);
377 (self.rotation.apply(&unit), norm.max(1e-10))
378 }
379
380 pub fn codes_bytes(&self) -> usize {
382 self.ids.len() * 8 + self.packed.len() * 8 + self.cos_lut.len() * 4
384 }
385
386 pub fn rotation(&self) -> &RandomRotation {
387 &self.rotation
388 }
389
390 pub fn codes_materialised(&self) -> Vec<(usize, BinaryCode)> {
395 (0..self.ids.len())
396 .map(|i| {
397 let s = i * self.n_words;
398 let words = self.packed[s..s + self.n_words].to_vec();
399 (
400 self.ids[i] as usize,
401 BinaryCode {
402 words,
403 norm: self.norms[i],
404 dim: self.dim,
405 },
406 )
407 })
408 .collect()
409 }
410
411 pub fn ids(&self) -> &[u32] {
413 &self.ids
414 }
415 pub fn norms(&self) -> &[f32] {
416 &self.norms
417 }
418 pub fn packed(&self) -> &[u64] {
419 &self.packed
420 }
421 pub fn n_words(&self) -> usize {
422 self.n_words
423 }
424 pub fn cos_lut(&self) -> &[f32] {
425 &self.cos_lut
426 }
427
428 #[inline]
432 pub(crate) fn symmetric_scan_topk(
433 &self,
434 q_packed: &[u64],
435 q_norm: f32,
436 k: usize,
437 ) -> Vec<(u32, u32, f32)> {
438 let n = self.ids.len();
440 let mut top = TopK::new(k.min(n));
441 let q_sq = q_norm * q_norm;
442 let lut = &self.cos_lut;
443
444 let mut agree = vec![0u32; n];
448 crate::scan::scan(
449 &self.packed,
450 self.n_words,
451 n,
452 q_packed,
453 self.last_word_mask,
454 &mut agree,
455 );
456
457 for i in 0..n {
460 let est_cos = unsafe { *lut.get_unchecked(*agree.get_unchecked(i) as usize) };
463 let x_norm = self.norms[i];
464 let est_ip = q_norm * x_norm * est_cos;
465 let score = q_sq + x_norm * x_norm - 2.0 * est_ip;
466 top.push_raw(self.ids[i] as usize, score, i);
467 }
468 top.into_sorted_with_pos()
469 }
470}
471
472impl AnnIndex for RabitqIndex {
473 fn add(&mut self, id: usize, vector: Vec<f32>) -> Result<()> {
474 if vector.len() != self.dim {
475 return Err(RabitqError::DimensionMismatch {
476 expected: self.dim,
477 actual: vector.len(),
478 });
479 }
480 let norm: f32 = vector.iter().map(|&x| x * x).sum::<f32>().sqrt();
482 let mut unit = vector;
483 normalize_inplace(&mut unit);
484 let rotated = self.rotation.apply(&unit);
485 let start = self.packed.len();
486 self.packed.resize(start + self.n_words, 0);
487 let slot = &mut self.packed[start..start + self.n_words];
488 for (i, &v) in rotated.iter().enumerate() {
489 if v >= 0.0 {
490 slot[i / 64] |= 1u64 << (63 - (i % 64));
491 }
492 }
493 self.ids.push(id as u32);
494 self.norms.push(norm);
495 Ok(())
496 }
497
498 fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
499 if self.ids.is_empty() {
500 return Err(RabitqError::EmptyIndex);
501 }
502 if query.len() != self.dim {
503 return Err(RabitqError::DimensionMismatch {
504 expected: self.dim,
505 actual: query.len(),
506 });
507 }
508 let (q_packed, q_norm) = self.encode_query_packed(query);
509 let results = self.symmetric_scan_topk(&q_packed, q_norm, k);
510 Ok(results
511 .into_iter()
512 .map(|(_, id, score)| SearchResult {
513 id: id as usize,
514 score,
515 })
516 .collect())
517 }
518
519 fn len(&self) -> usize {
520 self.ids.len()
521 }
522 fn dim(&self) -> usize {
523 self.dim
524 }
525 fn memory_bytes(&self) -> usize {
526 self.rotation.bytes() + self.codes_bytes()
527 }
528}
529
530pub struct RabitqPlusIndex {
543 inner: RabitqIndex,
544 originals_flat: Vec<f32>,
547 rerank_factor: usize,
548}
549
550impl RabitqPlusIndex {
551 #[inline]
552 fn original(&self, pos: usize) -> &[f32] {
553 let dim = self.inner.dim;
554 &self.originals_flat[pos * dim..(pos + 1) * dim]
555 }
556
557 pub fn external_ids(&self) -> &[u32] {
567 self.inner.ids()
568 }
569
570 pub fn ids_u64(&self) -> Vec<u64> {
574 self.inner.ids().iter().map(|&id| id as u64).collect()
575 }
576
577 pub fn export_items(&self) -> Vec<(usize, Vec<f32>)> {
590 let dim = self.inner.dim;
591 let n = self.inner.ids.len();
592 (0..n)
593 .map(|pos| {
594 (
595 pos,
596 self.originals_flat[pos * dim..(pos + 1) * dim].to_vec(),
597 )
598 })
599 .collect()
600 }
601}
602
603impl RabitqPlusIndex {
604 pub fn new(dim: usize, seed: u64, rerank_factor: usize) -> Self {
607 Self::new_with_rotation(dim, seed, rerank_factor, RandomRotationKind::HaarDense)
608 }
609
610 pub fn new_with_rotation(
613 dim: usize,
614 seed: u64,
615 rerank_factor: usize,
616 kind: RandomRotationKind,
617 ) -> Self {
618 Self {
619 inner: RabitqIndex::new_with_rotation(dim, seed, kind),
620 originals_flat: Vec::new(),
621 rerank_factor: rerank_factor.max(1),
622 }
623 }
624
625 pub fn rerank_factor(&self) -> usize {
626 self.rerank_factor
627 }
628 pub fn set_rerank_factor(&mut self, f: usize) {
629 self.rerank_factor = f.max(1);
630 }
631
632 pub fn from_vectors_parallel(
644 dim: usize,
645 seed: u64,
646 rerank_factor: usize,
647 items: Vec<(usize, Vec<f32>)>,
648 ) -> Result<Self> {
649 Self::from_vectors_parallel_with_rotation(
650 dim,
651 seed,
652 rerank_factor,
653 RandomRotationKind::HaarDense,
654 items,
655 )
656 }
657
658 pub fn from_vectors_parallel_with_rotation(
662 dim: usize,
663 seed: u64,
664 rerank_factor: usize,
665 kind: RandomRotationKind,
666 items: Vec<(usize, Vec<f32>)>,
667 ) -> Result<Self> {
668 use rayon::prelude::*;
669 let mut out = Self::new_with_rotation(dim, seed, rerank_factor, kind);
670 for (_, v) in &items {
671 if v.len() != dim {
672 return Err(RabitqError::DimensionMismatch {
673 expected: dim,
674 actual: v.len(),
675 });
676 }
677 }
678 let encoded: Vec<(usize, Vec<u64>, f32, Vec<f32>)> = items
682 .into_par_iter()
683 .map(|(id, v)| {
684 let (packed, _) = out.inner.encode_query_packed(&v);
685 let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
686 (id, packed, norm, v)
687 })
688 .collect();
689 let n = encoded.len();
692 let n_words = out.inner.n_words;
693 out.inner.packed.reserve(n * n_words);
694 out.inner.ids.reserve(n);
695 out.inner.norms.reserve(n);
696 out.originals_flat.reserve(n * dim);
697 for (id, packed, norm, v) in encoded {
698 debug_assert_eq!(packed.len(), n_words);
699 debug_assert_eq!(v.len(), dim);
700 out.inner.packed.extend_from_slice(&packed);
701 out.inner.ids.push(id as u32);
702 out.inner.norms.push(norm);
703 out.originals_flat.extend_from_slice(&v);
704 }
705 Ok(out)
706 }
707
708 pub fn search_with_rerank(
719 &self,
720 query: &[f32],
721 k: usize,
722 rerank_factor: usize,
723 ) -> Result<Vec<SearchResult>> {
724 if self.inner.ids.is_empty() {
725 return Err(RabitqError::EmptyIndex);
726 }
727 if query.len() != self.inner.dim {
728 return Err(RabitqError::DimensionMismatch {
729 expected: self.inner.dim,
730 actual: query.len(),
731 });
732 }
733 let rf = rerank_factor.max(1);
734 let n = self.inner.ids.len();
735 let candidates = k.saturating_mul(rf).max(k).min(n);
736
737 let (q_packed, q_norm) = self.inner.encode_query_packed(query);
738 let cand = self
739 .inner
740 .symmetric_scan_topk(&q_packed, q_norm, candidates);
741
742 let k_eff = k.min(cand.len());
743 let mut top = TopK::new(k_eff);
744 for (pos, id, _score) in &cand {
745 let v = self.original(*pos as usize);
746 top.push(*id as usize, sq_l2(query, v));
747 }
748 Ok(top.into_sorted_asc())
749 }
750}
751
752impl AnnIndex for RabitqPlusIndex {
753 fn add(&mut self, id: usize, vector: Vec<f32>) -> Result<()> {
754 let dim = self.inner.dim;
755 if vector.len() != dim {
756 return Err(RabitqError::DimensionMismatch {
757 expected: dim,
758 actual: vector.len(),
759 });
760 }
761 self.originals_flat.extend_from_slice(&vector);
765 self.inner.add(id, vector)?;
766 Ok(())
767 }
768
769 fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
770 if self.inner.ids.is_empty() {
771 return Err(RabitqError::EmptyIndex);
772 }
773 if query.len() != self.inner.dim {
774 return Err(RabitqError::DimensionMismatch {
775 expected: self.inner.dim,
776 actual: query.len(),
777 });
778 }
779 let n = self.inner.ids.len();
780 let candidates = k.saturating_mul(self.rerank_factor).max(k).min(n);
781
782 let (q_packed, q_norm) = self.inner.encode_query_packed(query);
784 let cand = self
785 .inner
786 .symmetric_scan_topk(&q_packed, q_norm, candidates);
787
788 let k_eff = k.min(cand.len());
792 let mut top = TopK::new(k_eff);
793 for (pos, id, _score) in &cand {
794 let v = self.original(*pos as usize);
795 top.push(*id as usize, sq_l2(query, v));
796 }
797 Ok(top.into_sorted_asc())
798 }
799
800 fn len(&self) -> usize {
801 self.inner.len()
802 }
803 fn dim(&self) -> usize {
804 self.inner.dim()
805 }
806 fn memory_bytes(&self) -> usize {
807 self.inner.memory_bytes() + 24 + self.originals_flat.len() * 4
810 }
811}
812
813pub struct RabitqAsymIndex {
819 inner: RabitqIndex,
820 originals: Vec<Vec<f32>>, rerank_factor: usize,
822 store_originals: bool,
823}
824
825impl RabitqAsymIndex {
826 pub fn new(dim: usize, seed: u64, rerank_factor: usize) -> Self {
829 let rf = rerank_factor.max(1);
830 Self {
831 inner: RabitqIndex::new(dim, seed),
832 originals: Vec::new(),
833 rerank_factor: rf,
834 store_originals: rf > 1,
835 }
836 }
837}
838
839impl AnnIndex for RabitqAsymIndex {
840 fn add(&mut self, id: usize, vector: Vec<f32>) -> Result<()> {
841 if self.store_originals {
842 self.originals.push(vector.clone());
843 }
844 self.inner.add(id, vector)
845 }
846
847 fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
848 if self.inner.ids.is_empty() {
849 return Err(RabitqError::EmptyIndex);
850 }
851 if query.len() != self.inner.dim {
852 return Err(RabitqError::DimensionMismatch {
853 expected: self.inner.dim,
854 actual: query.len(),
855 });
856 }
857 let n = self.inner.ids.len();
858 let candidates = k.saturating_mul(self.rerank_factor).max(k).min(n);
859
860 let (q_rot_unit, q_norm) = self.inner.prepare_query_f32(query);
861
862 let d = self.inner.dim;
866 let n_words = self.inner.n_words;
867 let inv_sqrt_d = 1.0 / (d as f32).sqrt();
868 let q_sq = q_norm * q_norm;
869
870 let mut top_cand = TopK::new(candidates);
871 for i in 0..n {
872 let base = i * n_words;
873 let slot = &self.inner.packed[base..base + n_words];
874 let mut ip = 0.0f32;
875 for (idx, &q_i) in q_rot_unit.iter().enumerate() {
876 let bit_set = (slot[idx / 64] >> (63 - (idx % 64))) & 1 == 1;
877 ip += if bit_set { q_i } else { -q_i };
878 }
879 let unit_ip = ip * inv_sqrt_d;
880 let x_norm = self.inner.norms[i];
881 let est_ip = q_norm * x_norm * unit_ip;
882 let score = q_sq + x_norm * x_norm - 2.0 * est_ip;
883 top_cand.push_raw(self.inner.ids[i] as usize, score, i);
884 }
885 let cand = top_cand.into_sorted_with_pos();
886
887 if self.rerank_factor <= 1 || !self.store_originals {
888 let k_eff = k.min(cand.len());
889 let mut out: Vec<SearchResult> = cand
890 .into_iter()
891 .take(k_eff)
892 .map(|(_, id, score)| SearchResult {
893 id: id as usize,
894 score,
895 })
896 .collect();
897 out.sort_unstable_by(|a, b| cmp_score_asc(a.score, b.score));
898 return Ok(out);
899 }
900
901 let k_eff = k.min(cand.len());
902 let mut top = TopK::new(k_eff);
903 for (pos, id, _) in &cand {
904 let v = &self.originals[*pos as usize];
905 top.push(*id as usize, sq_l2(query, v));
906 }
907 Ok(top.into_sorted_asc())
908 }
909
910 fn len(&self) -> usize {
911 self.inner.len()
912 }
913 fn dim(&self) -> usize {
914 self.inner.dim()
915 }
916 fn memory_bytes(&self) -> usize {
917 let mut b = self.inner.memory_bytes();
918 if self.store_originals {
919 b += self.originals.len() * (self.inner.dim * 4 + 24);
920 }
921 b
922 }
923}
924
#[cfg(test)]
mod tests {
    use super::*;

    /// `n` vectors of dimension `d` with components uniform in [-1, 1),
    /// tagged with ids `0..n`. Deterministic for a given `seed`.
    fn make_dataset(n: usize, d: usize, seed: u64) -> Vec<(usize, Vec<f32>)> {
        use rand::{Rng as _, SeedableRng as _};
        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
        (0..n)
            .map(|i| {
                let v: Vec<f32> = (0..d).map(|_| rng.gen::<f32>() * 2.0 - 1.0).collect();
                (i, v)
            })
            .collect()
    }

    /// `n` vectors scattered around `n_clusters` random centroids (components
    /// in [-2, 2)), with per-component uniform jitter in [-0.15, 0.15).
    fn make_clustered(n: usize, d: usize, n_clusters: usize, seed: u64) -> Vec<Vec<f32>> {
        use rand::{Rng as _, SeedableRng as _};
        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
        let centroids: Vec<Vec<f32>> = (0..n_clusters)
            .map(|_| {
                (0..d)
                    .map(|_| rng.gen::<f32>() * 4.0 - 2.0)
                    .collect::<Vec<_>>()
            })
            .collect();
        (0..n)
            .map(|_| {
                let c = &centroids[rng.gen_range(0..n_clusters)];
                c.iter()
                    .map(|&x| x + (rng.gen::<f32>() - 0.5) * 0.3)
                    .collect()
            })
            .collect()
    }

    #[test]
    fn flat_f32_returns_exact_nn() {
        let d = 64;
        let mut idx = FlatF32Index::new(d);
        let data = make_dataset(200, d, 1);
        for (id, v) in &data {
            idx.add(*id, v.clone()).unwrap();
        }
        let query = &data[7].1;
        let results = idx.search(query, 1).unwrap();
        assert_eq!(results[0].id, 7);
        assert!(results[0].score < 1e-6);
    }

    #[test]
    fn rabitq_recall_above_random() {
        let d = 128;
        let n = 1000;
        let nq = 100;
        let all_data = make_clustered(n + nq, d, 20, 42);
        let (db_vecs, query_vecs) = all_data.split_at(n);
        let data: Vec<(usize, Vec<f32>)> = db_vecs.iter().cloned().enumerate().collect();
        let queries: Vec<Vec<f32>> = query_vecs.to_vec();

        let mut exact = FlatF32Index::new(d);
        let mut idx = RabitqIndex::new(d, 42);
        for (id, v) in &data {
            exact.add(*id, v.clone()).unwrap();
            idx.add(*id, v.clone()).unwrap();
        }

        let k = 10;
        let mut hits = 0usize;
        for q in &queries {
            let e: std::collections::HashSet<usize> =
                exact.search(q, k).unwrap().iter().map(|r| r.id).collect();
            hits += idx
                .search(q, k)
                .unwrap()
                .iter()
                .filter(|r| e.contains(&r.id))
                .count();
        }
        let recall = hits as f64 / (nq * k) as f64;
        assert!(
            recall > 0.20,
            "recall@10={:.1}% — not above 20 % baseline",
            recall * 100.0
        );
    }

    #[test]
    fn rabitq_plus_recall_above_90pct() {
        let d = 128;
        let n = 1000;
        let nq = 100;
        let all_data = make_clustered(n + nq, d, 20, 55);
        let (db_vecs, query_vecs) = all_data.split_at(n);
        let data: Vec<(usize, Vec<f32>)> = db_vecs.iter().cloned().enumerate().collect();
        let queries: Vec<Vec<f32>> = query_vecs.to_vec();

        let mut exact = FlatF32Index::new(d);
        let mut idx = RabitqPlusIndex::new(d, 55, 5);
        for (id, v) in &data {
            exact.add(*id, v.clone()).unwrap();
            idx.add(*id, v.clone()).unwrap();
        }
        let k = 10;
        let mut hits = 0usize;
        for q in &queries {
            let e: std::collections::HashSet<usize> =
                exact.search(q, k).unwrap().iter().map(|r| r.id).collect();
            hits += idx
                .search(q, k)
                .unwrap()
                .iter()
                .filter(|r| e.contains(&r.id))
                .count();
        }
        let recall = hits as f64 / (nq * k) as f64;
        assert!(
            recall > 0.90,
            "rerank×5 recall@10={:.1}% < 90 %",
            recall * 100.0
        );
    }

    #[test]
    fn asymmetric_meets_or_beats_symmetric() {
        let d = 128;
        let n = 1000;
        let nq = 100;
        let all_data = make_clustered(n + nq, d, 20, 77);
        let (db_vecs, query_vecs) = all_data.split_at(n);
        let data: Vec<(usize, Vec<f32>)> = db_vecs.iter().cloned().enumerate().collect();
        let queries: Vec<Vec<f32>> = query_vecs.to_vec();

        let mut exact = FlatF32Index::new(d);
        let mut sym = RabitqIndex::new(d, 77);
        let mut asym = RabitqAsymIndex::new(d, 77, 1);
        for (id, v) in &data {
            exact.add(*id, v.clone()).unwrap();
            sym.add(*id, v.clone()).unwrap();
            asym.add(*id, v.clone()).unwrap();
        }
        let k = 10;
        let mut sh = 0usize;
        let mut ah = 0usize;
        for q in &queries {
            let e: std::collections::HashSet<usize> =
                exact.search(q, k).unwrap().iter().map(|r| r.id).collect();
            sh += sym
                .search(q, k)
                .unwrap()
                .iter()
                .filter(|r| e.contains(&r.id))
                .count();
            ah += asym
                .search(q, k)
                .unwrap()
                .iter()
                .filter(|r| e.contains(&r.id))
                .count();
        }
        let sr = sh as f64 / (nq * k) as f64;
        let ar = ah as f64 / (nq * k) as f64;
        eprintln!("sym={:.1}% asym={:.1}%", sr * 100.0, ar * 100.0);
        assert!(ar + 0.02 >= sr, "asymmetric regressed vs symmetric");
    }

    #[test]
    fn recall_holds_at_non_aligned_dim() {
        let d = 100;
        let n = 500;
        let nq = 50;
        let all_data = make_clustered(n + nq, d, 15, 17);
        let (db_vecs, query_vecs) = all_data.split_at(n);
        let data: Vec<(usize, Vec<f32>)> = db_vecs.iter().cloned().enumerate().collect();
        let queries: Vec<Vec<f32>> = query_vecs.to_vec();

        let mut exact = FlatF32Index::new(d);
        let mut idx = RabitqPlusIndex::new(d, 17, 5);
        for (id, v) in &data {
            exact.add(*id, v.clone()).unwrap();
            idx.add(*id, v.clone()).unwrap();
        }
        let k = 10;
        let mut hits = 0usize;
        for q in &queries {
            let e: std::collections::HashSet<usize> =
                exact.search(q, k).unwrap().iter().map(|r| r.id).collect();
            hits += idx
                .search(q, k)
                .unwrap()
                .iter()
                .filter(|r| e.contains(&r.id))
                .count();
        }
        let r = hits as f64 / (nq * k) as f64;
        assert!(r > 0.80, "D=100 rerank×5 recall={:.1}% < 80 %", r * 100.0);
    }

    #[test]
    fn nan_query_does_not_panic() {
        let d = 64;
        let mut idx = RabitqIndex::new(d, 42);
        let data = make_dataset(100, d, 3);
        for (id, v) in &data {
            idx.add(*id, v.clone()).unwrap();
        }
        let mut q = data[0].1.clone();
        q[5] = f32::NAN;
        let _ = idx.search(&q, 5);
    }

    #[test]
    fn memory_accounting_is_honest() {
        let d = 256;
        let n = 1000;
        let data = make_dataset(n, d, 0);
        let mut flat = FlatF32Index::new(d);
        let mut rq = RabitqIndex::new(d, 0);
        let mut rq_plus = RabitqPlusIndex::new(d, 0, 5);
        for (id, v) in &data {
            flat.add(*id, v.clone()).unwrap();
            rq.add(*id, v.clone()).unwrap();
            rq_plus.add(*id, v.clone()).unwrap();
        }
        let f = flat.memory_bytes();
        let rqb = rq.memory_bytes();
        let rqpb = rq_plus.memory_bytes();
        assert!(rqb < f, "RabitqIndex {rqb} should be < Flat {f}");
        assert!(
            rqpb > f,
            "RabitqPlusIndex {rqpb} should be > Flat {f} (rerank stores both)"
        );
    }

    #[test]
    fn heap_topk_is_sorted_ascending() {
        let d = 64;
        let mut idx = FlatF32Index::new(d);
        let data = make_dataset(50, d, 2);
        for (id, v) in &data {
            idx.add(*id, v.clone()).unwrap();
        }
        let r = idx.search(&data[0].1, 10).unwrap();
        assert_eq!(r.len(), 10);
        for w in r.windows(2) {
            assert!(w[0].score <= w[1].score, "top-k not ascending: {:?}", r);
        }
    }

    #[test]
    fn hadamard_index_builds_and_searches() {
        let d = 128;
        let n = 500;
        let nq = 10;
        let all_data = make_clustered(n + nq, d, 12, 2026);
        let (db_vecs, query_vecs) = all_data.split_at(n);
        let data: Vec<(usize, Vec<f32>)> = db_vecs.iter().cloned().enumerate().collect();

        let idx = RabitqPlusIndex::from_vectors_parallel_with_rotation(
            d,
            2026,
            5,
            RandomRotationKind::HadamardSigned,
            data,
        )
        .expect("bulk-build with Hadamard rotation");
        assert_eq!(idx.len(), n);
        assert_eq!(idx.dim(), d);

        let k = 10;
        for q in query_vecs {
            let res = idx.search(q, k).unwrap();
            assert_eq!(res.len(), k, "expected {k} results, got {}", res.len());
            for r in &res {
                assert!(
                    r.score.is_finite(),
                    "Hadamard-rotated result has non-finite score: {r:?}",
                );
            }
        }
    }

    #[test]
    fn hadamard_recall_at_10_within_5pct_of_haar() {
        let d = 128;
        let n = 500;
        let nq = 50;
        let all_data = make_clustered(n + nq, d, 16, 131);
        let (db_vecs, query_vecs) = all_data.split_at(n);
        let data: Vec<(usize, Vec<f32>)> = db_vecs.iter().cloned().enumerate().collect();

        let mut exact = FlatF32Index::new(d);
        for (id, v) in &data {
            exact.add(*id, v.clone()).unwrap();
        }

        let seed = 131_u64;
        let rerank = 20;
        let mut haar =
            RabitqPlusIndex::new_with_rotation(d, seed, rerank, RandomRotationKind::HaarDense);
        let mut had =
            RabitqPlusIndex::new_with_rotation(d, seed, rerank, RandomRotationKind::HadamardSigned);
        for (id, v) in &data {
            haar.add(*id, v.clone()).unwrap();
            had.add(*id, v.clone()).unwrap();
        }

        let k = 10;
        let mut haar_hits = 0usize;
        let mut had_hits = 0usize;
        for q in query_vecs {
            let gt: std::collections::HashSet<usize> =
                exact.search(q, k).unwrap().iter().map(|r| r.id).collect();
            haar_hits += haar
                .search(q, k)
                .unwrap()
                .iter()
                .filter(|r| gt.contains(&r.id))
                .count();
            had_hits += had
                .search(q, k)
                .unwrap()
                .iter()
                .filter(|r| gt.contains(&r.id))
                .count();
        }
        let haar_recall = haar_hits as f64 / (nq * k) as f64;
        let had_recall = had_hits as f64 / (nq * k) as f64;
        eprintln!(
            "hadamard_recall_at_10_within_5pct_of_haar: haar={:.3} had={:.3}",
            haar_recall, had_recall
        );
        assert!(
            had_recall >= 0.85,
            "Hadamard recall@10={had_recall:.3} < 0.85 (haar={haar_recall:.3})",
        );
    }

    #[test]
    fn hadamard_rotation_memory_smaller_than_haar() {
        let d = 128;
        let haar = RabitqIndex::new_with_rotation(d, 0, RandomRotationKind::HaarDense);
        let had = RabitqIndex::new_with_rotation(d, 0, RandomRotationKind::HadamardSigned);

        let haar_bytes = haar.memory_bytes();
        let had_bytes = had.memory_bytes();
        eprintln!(
            "hadamard_rotation_memory_smaller_than_haar: haar={haar_bytes}B had={had_bytes}B ratio={:.1}x",
            haar_bytes as f64 / had_bytes as f64,
        );
        assert!(
            had_bytes * 30 <= haar_bytes,
            "Hadamard memory={had_bytes} vs Haar={haar_bytes} — expected ≥ 30× reduction",
        );
    }

    #[test]
    fn export_items_roundtrip_via_from_vectors_parallel() {
        let d = 16;
        let n = 100;
        let seed = 20_260_423_u64;
        let rerank = 4;
        let kind = RandomRotationKind::HaarDense;

        let data = make_dataset(n, d, seed);

        let mut src = RabitqPlusIndex::new_with_rotation(d, seed, rerank, kind);
        for (id, v) in &data {
            src.add(*id, v.clone()).unwrap();
        }
        assert_eq!(src.len(), n);

        let items = src.export_items();
        assert_eq!(items.len(), n);
        for (pos, row) in &items {
            assert_eq!(row.len(), d, "row {pos} wrong dim");
        }

        let rebuilt =
            RabitqPlusIndex::from_vectors_parallel_with_rotation(d, seed, rerank, kind, items)
                .expect("rebuild from export_items");
        assert_eq!(rebuilt.len(), n);
        assert_eq!(rebuilt.dim(), d);

        let queries = make_dataset(5, d, seed ^ 0xDEAD_BEEF);
        let k = 10;
        for (_, q) in &queries {
            let a = src.search(q, k).unwrap();
            let b = rebuilt.search(q, k).unwrap();
            assert_eq!(a.len(), b.len(), "result count differs");
            for (ra, rb) in a.iter().zip(b.iter()) {
                assert_eq!(ra.id, rb.id, "id mismatch on query");
                assert_eq!(
                    ra.score.to_bits(),
                    rb.score.to_bits(),
                    "score bits differ for id={}",
                    ra.id,
                );
            }
        }
    }
}