ruvector_rabitq/quantize.rs
1//! Bit-packing, XNOR-popcount, and the two RaBitQ distance estimators:
2//!
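//! 1. **Symmetric Charikar-style angular estimator** — both query and database
//!    are 1-bit. Derived from the hyperplane-LSH collision probability, where
//!    `B` is the number of agreeing (non-padding) bits and θ is the angle
//!    between the two vectors: E[B/D] = 1 − θ/π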
6//! This is what the shipped crate had at commit `f2dbb6efb`.
7//!
//! 2. **Asymmetric RaBitQ-2024 inner-product estimator** — query stays in f32,
//!    database is 1-bit `b_i ∈ {−1/√D, +1/√D}`. The inner product is
//!    reconstructed by summing the rotated query's components with the stored
//!    signs, then rescaled by a precomputed factor that corrects the expected
//!    shrinkage sign quantisation introduces on the unit sphere. Unbiased for
//!    Haar-uniform rotations, with O(1/√D) estimation error.
13//!
14//! The asymmetric path closes the gap between this crate's estimator and the
15//! SIGMOD 2024 paper (Gao & Long) — the symmetric path remains for
16//! apples-to-apples comparison against naive 1-bit codes.
17//!
18//! ## Bit packing
19//!
20//! Each dimension is one bit: 1 if the rotated value ≥ 0, else 0. Bits are
21//! packed MSB-first into u64 words. When `D % 64 != 0` the last word carries
22//! `64·n_words − D` padding bits that are zero in every code; XNOR-popcount
23//! must mask those bits off before counting, otherwise padding bits always
24//! agree and the estimator is biased. `masked_xnor_popcount` handles this
25//! correctly; `xnor_popcount` is retained for the aligned case.
26
/// A packed binary code representing one vector (D bits).
///
/// As produced by [`BinaryCode::encode`], bit `i` (MSB-first within each
/// `u64` word) is 1 when component `i` of the encoded vector was `>= 0.0` and
/// 0 otherwise; the padding bits past `dim` in the last word are always 0.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct BinaryCode {
    /// Packed u64 words (ceil(D/64) words).
    ///
    /// `masked_xnor_popcount` assumes `words.len() == ceil(dim / 64)`; codes
    /// built by hand or deserialized should preserve that invariant.
    pub words: Vec<u64>,
    /// Original L2 norm before normalisation (needed for the IP estimator).
    pub norm: f32,
    /// Number of dimensions.
    pub dim: usize,
}
37
38impl BinaryCode {
39 /// Encode a (possibly rotated) vector into a binary code.
40 ///
41 /// `norm` should be the L2 norm of the *pre-rotation* vector so the estimator
42 /// can rescale correctly.
43 pub fn encode(rotated: &[f32], norm: f32) -> Self {
44 let dim = rotated.len();
45 let n_words = (dim + 63) / 64;
46 let mut words = vec![0u64; n_words];
47 for (i, &v) in rotated.iter().enumerate() {
48 if v >= 0.0 {
49 words[i / 64] |= 1u64 << (63 - (i % 64));
50 }
51 }
52 Self { words, norm, dim }
53 }
54
55 /// Raw XNOR-popcount across all stored bits. **Do not use when
56 /// `D % 64 != 0`** — the padding bits in the last word are zero in every
57 /// code and XNOR-popcount counts them as matches, biasing the estimator.
58 /// Retained as a fast path for the aligned case (D multiple of 64).
59 #[inline]
60 pub fn xnor_popcount(&self, other: &Self) -> u32 {
61 debug_assert_eq!(self.words.len(), other.words.len());
62 self.words
63 .iter()
64 .zip(other.words.iter())
65 .map(|(&a, &b)| (!(a ^ b)).count_ones())
66 .sum()
67 }
68
69 /// Padding-safe XNOR-popcount. Masks the trailing
70 /// `64·n_words − D` bits of the last word so padding zeros don't inflate
71 /// the agreement count. Correct at any `D ≥ 1`; same cost as the raw
72 /// version up to one extra AND on the last word.
73 #[inline]
74 pub fn masked_xnor_popcount(&self, other: &Self) -> u32 {
75 debug_assert_eq!(self.words.len(), other.words.len());
76 debug_assert_eq!(self.dim, other.dim);
77 let n_words = self.words.len();
78 if n_words == 0 {
79 return 0;
80 }
81 let mut sum: u32 = 0;
82 for i in 0..n_words - 1 {
83 sum += (!(self.words[i] ^ other.words[i])).count_ones();
84 }
85 // Last word: mask off the padding bits that were never written.
86 let valid_bits = self.dim - 64 * (n_words - 1);
87 let mask: u64 = if valid_bits == 64 {
88 !0u64
89 } else {
90 // Keep the top `valid_bits` MSBs (because we packed MSB-first).
91 !0u64 << (64 - valid_bits)
92 };
93 let last = !(self.words[n_words - 1] ^ other.words[n_words - 1]) & mask;
94 sum += last.count_ones();
95 sum
96 }
97
98 /// **Symmetric** angular estimator (Charikar-style) — both operands are
99 /// 1-bit codes of rotated unit vectors.
100 ///
101 /// For normalized database x̂ (`self.norm` holds the original ‖x‖) and
102 /// normalized query q̂ (`query_code.norm` holds the original ‖q‖):
103 ///
104 /// E[B/D] = 1 − θ/π where θ = arccos(⟨x̂, q̂⟩)
105 /// ⟹ est cos(θ) = cos(π · (1 − B/D))
106 /// ⟹ est ⟨q, x⟩ = ‖q‖ · ‖x‖ · est cos(θ)
107 ///
108 /// Returns estimated squared-L2: ‖q − x‖² = ‖q‖² + ‖x‖² − 2⟨q, x⟩.
109 #[inline]
110 pub fn estimated_sq_distance(&self, query_code: &Self) -> f32 {
111 use std::f32::consts::PI;
112 let d = self.dim as f32;
113 let agreement = self.masked_xnor_popcount(query_code) as f32;
114 let est_cos = (PI * (1.0 - agreement / d)).cos();
115 let est_ip = self.norm * query_code.norm * est_cos;
116 let q_sq = query_code.norm * query_code.norm;
117 q_sq + self.norm * self.norm - 2.0 * est_ip
118 }
119
120 /// **Asymmetric** inner-product estimator (RaBitQ-style, keeps the query
121 /// in f32). More accurate than the symmetric path, at the cost of
122 /// O(D) arithmetic per candidate instead of O(D/64) popcount.
123 ///
124 /// Given the rotated-unit query `q_rot` (‖q_rot‖ = 1) and the stored 1-bit
125 /// code `b_x` ∈ {−1/√D, +1/√D}ᴰ, the unbiased inner-product estimate is:
126 ///
127 /// ⟨q̂_rot, u_x⟩ ≈ (1/√D) · Σᵢ sign(x_rot,i) · q_rot,i
128 ///
129 /// where u_x is the rotated unit vector and `b_x,i = sign(x_rot,i)/√D`.
130 /// The unbiasing factor accounts for the concentration of
131 /// `Σ|q_rot,i|` on a Haar-uniform rotation of q (which preserves norm).
132 ///
133 /// Returns estimated squared-L2: `‖q − x‖² = ‖q‖² + ‖x‖² − 2‖q‖·‖x‖·ŝ`
134 /// where `ŝ = ⟨q̂_rot, u_x⟩` is the unit-sphere IP estimate above.
135 ///
136 /// `q_rotated` must be length `self.dim`; caller pre-normalises and
137 /// pre-rotates the query once per search (amortised across n candidates).
138 #[inline]
139 pub fn estimated_sq_distance_asymmetric(&self, q_rotated_unit: &[f32], q_norm: f32) -> f32 {
140 debug_assert_eq!(q_rotated_unit.len(), self.dim);
141 let d = self.dim;
142 let inv_sqrt_d = 1.0 / (d as f32).sqrt();
143 // Σᵢ sign(x_rot,i) · q_rot,i without materialising signs: bit = 1
144 // means +1, bit = 0 means −1.
145 let mut ip = 0.0f32;
146 for (i, &q_i) in q_rotated_unit.iter().enumerate() {
147 let bit_set = (self.words[i / 64] >> (63 - (i % 64))) & 1 == 1;
148 ip += if bit_set { q_i } else { -q_i };
149 }
150 let unit_ip = ip * inv_sqrt_d;
151 let est_ip = q_norm * self.norm * unit_ip;
152 q_norm * q_norm + self.norm * self.norm - 2.0 * est_ip
153 }
154}
155
/// Pack a boolean slice into `u64` words for testing/utilities.
///
/// Uses the same MSB-first layout as `BinaryCode::encode`: `bits[i]` is stored
/// in word `i / 64` at bit position `63 − i % 64`. The result has
/// `ceil(bits.len() / 64)` words, and unused trailing bits of the last word
/// are zero.
pub fn pack_bits(bits: &[bool]) -> Vec<u64> {
    bits.chunks(64)
        .map(|chunk| {
            chunk
                .iter()
                .enumerate()
                .filter(|&(_, &set)| set)
                .fold(0u64, |word, (offset, _)| word | (1u64 << (63 - offset)))
        })
        .collect()
}
167
/// Read the first `dim` bits out of MSB-first packed `words` (the layout
/// written by `pack_bits`), returning one `bool` per dimension.
///
/// Panics with an index-out-of-bounds error if `words` holds fewer than
/// `dim` bits.
pub fn unpack_bits(words: &[u64], dim: usize) -> Vec<bool> {
    let mut bits = Vec::with_capacity(dim);
    for i in 0..dim {
        let (word, offset) = (words[i / 64], i % 64);
        bits.push((word >> (63 - offset)) & 1 == 1);
    }
    bits
}
174
// Unit tests for bit packing, the raw vs. padding-masked XNOR-popcount, and
// the symmetric / asymmetric distance estimators. The vectors used here are
// not randomly rotated, so these tests check exact popcounts and relative
// ordering rather than the estimators' rotation-averaged unbiasedness.
#[cfg(test)]
mod tests {
    use super::*;

    // 130 bits spans three u64 words with a partially filled last word, so
    // the roundtrip also exercises the padding path.
    #[test]
    fn pack_unpack_roundtrip() {
        let bits: Vec<bool> = (0..130).map(|i| i % 3 == 0).collect();
        let words = pack_bits(&bits);
        let unpacked = unpack_bits(&words, 130);
        assert_eq!(bits, unpacked);
    }

    // D = 64 fills exactly one word, so the raw popcount has no padding to
    // miscount: a code agrees with itself on all 64 bits.
    #[test]
    fn xnor_self_is_all_ones() {
        let v: Vec<f32> = (0..64)
            .map(|i| if i % 2 == 0 { 1.0 } else { -1.0 })
            .collect();
        let code = BinaryCode::encode(&v, 1.0);
        let agreement = code.xnor_popcount(&code);
        assert_eq!(
            agreement, 64,
            "self-agreement should be D=64, got {agreement}"
        );
    }

    // Negating every (non-zero) component flips every sign bit, so an aligned
    // code and its negation agree on no bit.
    #[test]
    fn xnor_opposite_is_zero() {
        let v: Vec<f32> = (0..64)
            .map(|i| if i % 2 == 0 { 1.0 } else { -1.0 })
            .collect();
        let neg_v: Vec<f32> = v.iter().map(|&x| -x).collect();
        let code = BinaryCode::encode(&v, 1.0);
        let code_neg = BinaryCode::encode(&neg_v, 1.0);
        let agreement = code.xnor_popcount(&code_neg);
        assert_eq!(agreement, 0, "opposite vectors should have 0 agreement");
    }

    /// Bug surfaced by the deep review: at `D % 64 != 0` the padding bits in
    /// the last word are zero in every code, so XNOR-popcount counts them as
    /// matches. `masked_xnor_popcount` must not count padding.
    #[test]
    fn masked_popcount_handles_non_aligned_dim() {
        // D=100 → 2 u64 words, 28 padding bits.
        let v: Vec<f32> = (0..100)
            .map(|i| if i % 2 == 0 { 1.0 } else { -1.0 })
            .collect();
        let neg_v: Vec<f32> = v.iter().map(|&x| -x).collect();
        let code = BinaryCode::encode(&v, 1.0);
        let code_neg = BinaryCode::encode(&neg_v, 1.0);
        // Raw would read 0 matches + 28 padding matches = 28 (wrong).
        let raw = code.xnor_popcount(&code_neg);
        assert_eq!(
            raw, 28,
            "raw xnor should count padding as matches (bug demo)"
        );
        // Masked must report 0 matches.
        let masked = code.masked_xnor_popcount(&code_neg);
        assert_eq!(
            masked, 0,
            "masked xnor must ignore padding bits; got {masked}"
        );
        // Self-compare: every real bit matches, padding is masked.
        let self_masked = code.masked_xnor_popcount(&code);
        assert_eq!(self_masked, 100);
    }

    // With no padding (D = 128), masking the last word is a no-op.
    #[test]
    fn masked_popcount_matches_raw_when_aligned() {
        // D=128 is 64-aligned, so masked == raw.
        let v: Vec<f32> = (0..128).map(|i| (i as f32 * 0.1).sin()).collect();
        let w: Vec<f32> = (0..128).map(|i| (i as f32 * 0.1).cos()).collect();
        let ca = BinaryCode::encode(&v, 1.0);
        let cb = BinaryCode::encode(&w, 1.0);
        assert_eq!(ca.xnor_popcount(&cb), ca.masked_xnor_popcount(&cb));
    }

    #[test]
    fn estimated_distance_self_is_near_zero() {
        // A unit vector against itself should estimate distance ≈ 0.
        let v: Vec<f32> = (0..128).map(|i| (i as f32 / 128.0).sin()).collect();
        let norm: f32 = v.iter().map(|&x| x * x).sum::<f32>().sqrt();
        let unit: Vec<f32> = v.iter().map(|&x| x / norm).collect();
        let code = BinaryCode::encode(&unit, 1.0);
        let est = code.estimated_sq_distance(&code);
        // Symmetric Charikar estimator on the same code: cos(π·(1−D/D))=1 → est=0.
        assert!(
            est.abs() < 1e-5,
            "self sq-distance estimate too large: {est}"
        );
    }

    #[test]
    fn asymmetric_matches_symmetric_in_sign() {
        // The asymmetric IP estimator and the symmetric cos-angle estimator
        // should agree on which of two candidates is closer (even when the
        // magnitudes differ) — they encode the same angular signal.
        //
        // NOTE: the expected data depends on the order of RNG draws (q first,
        // then near's noise, then far); reordering these lines changes the
        // vectors under test.
        use rand::{Rng as _, SeedableRng as _};
        let mut rng = rand::rngs::StdRng::seed_from_u64(11);
        let d = 128;
        let q: Vec<f32> = (0..d).map(|_| rng.gen::<f32>() * 2.0 - 1.0).collect();
        let q_norm: f32 = q.iter().map(|&x| x * x).sum::<f32>().sqrt();
        let q_unit: Vec<f32> = q.iter().map(|&x| x / q_norm).collect();
        let qc = BinaryCode::encode(&q_unit, q_norm);

        // `near` is q plus a small non-negative perturbation; `far` is an
        // independent draw from the same distribution as q.
        let near: Vec<f32> = q.iter().map(|&x| x + rng.gen::<f32>() * 0.1).collect();
        let far: Vec<f32> = (0..d).map(|_| rng.gen::<f32>() * 2.0 - 1.0).collect();

        // Normalise, then encode, storing the original norm in the code.
        let encode_one = |v: &[f32]| {
            let n: f32 = v.iter().map(|&x| x * x).sum::<f32>().sqrt();
            let u: Vec<f32> = v.iter().map(|&x| x / n).collect();
            BinaryCode::encode(&u, n)
        };
        let cn = encode_one(&near);
        let cf = encode_one(&far);

        // Symmetric
        let s_near = cn.estimated_sq_distance(&qc);
        let s_far = cf.estimated_sq_distance(&qc);
        // Asymmetric: the "rotated unit query" here is just q_unit (no
        // rotation since we're testing the estimator math directly).
        let a_near = cn.estimated_sq_distance_asymmetric(&q_unit, q_norm);
        let a_far = cf.estimated_sq_distance_asymmetric(&q_unit, q_norm);
        assert!(s_near < s_far, "symmetric ordering wrong");
        assert!(a_near < a_far, "asymmetric ordering wrong");
    }
}