// ruvector_rabitq/quantize.rs
//! Bit-packing, XNOR-popcount, and the two RaBitQ distance estimators:
//!
//! 1. **Symmetric Charikar-style angular estimator** — both query and database
//!    are 1-bit. Derived from hyperplane-LSH collision probability:
//!        E[B/D] = 1 − θ/π
//!    This is what the shipped crate had at commit `f2dbb6efb`.
//!
//! 2. **Asymmetric RaBitQ-2024 inner-product estimator** — query stays in f32,
//!    database is 1-bit `b_i ∈ {−1/√D, +1/√D}`. Inner product is reconstructed
//!    by summing the rotated query's components with signs, then rescaled by
//!    a precomputed factor that corrects the unit-sphere inner-product bias
//!    of the sign code. Unbiased for Haar-uniform rotations with O(1/√D)
//!    error.
//!
//! The asymmetric path closes the gap between this crate's estimator and the
//! SIGMOD 2024 paper (Gao & Long) — the symmetric path remains for
//! apples-to-apples comparison against naive 1-bit codes.
//!
//! ## Bit packing
//!
//! Each dimension is one bit: 1 if the rotated value ≥ 0, else 0. Bits are
//! packed MSB-first into u64 words. When `D % 64 != 0` the last word carries
//! `64·n_words − D` padding bits that are zero in every code; XNOR-popcount
//! must mask those bits off before counting, otherwise padding bits always
//! agree and the estimator is biased. `masked_xnor_popcount` handles this
//! correctly; `xnor_popcount` is retained for the aligned case.

27/// A packed binary code representing one vector (D bits).
28#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
29pub struct BinaryCode {
30    /// Packed u64 words (ceil(D/64) words).
31    pub words: Vec<u64>,
32    /// Original L2 norm before normalisation (needed for the IP estimator).
33    pub norm: f32,
34    /// Number of dimensions.
35    pub dim: usize,
36}
37
38impl BinaryCode {
39    /// Encode a (possibly rotated) vector into a binary code.
40    ///
41    /// `norm` should be the L2 norm of the *pre-rotation* vector so the estimator
42    /// can rescale correctly.
43    pub fn encode(rotated: &[f32], norm: f32) -> Self {
44        let dim = rotated.len();
45        let n_words = (dim + 63) / 64;
46        let mut words = vec![0u64; n_words];
47        for (i, &v) in rotated.iter().enumerate() {
48            if v >= 0.0 {
49                words[i / 64] |= 1u64 << (63 - (i % 64));
50            }
51        }
52        Self { words, norm, dim }
53    }
54
55    /// Raw XNOR-popcount across all stored bits. **Do not use when
56    /// `D % 64 != 0`** — the padding bits in the last word are zero in every
57    /// code and XNOR-popcount counts them as matches, biasing the estimator.
58    /// Retained as a fast path for the aligned case (D multiple of 64).
59    #[inline]
60    pub fn xnor_popcount(&self, other: &Self) -> u32 {
61        debug_assert_eq!(self.words.len(), other.words.len());
62        self.words
63            .iter()
64            .zip(other.words.iter())
65            .map(|(&a, &b)| (!(a ^ b)).count_ones())
66            .sum()
67    }
68
69    /// Padding-safe XNOR-popcount. Masks the trailing
70    /// `64·n_words − D` bits of the last word so padding zeros don't inflate
71    /// the agreement count. Correct at any `D ≥ 1`; same cost as the raw
72    /// version up to one extra AND on the last word.
73    #[inline]
74    pub fn masked_xnor_popcount(&self, other: &Self) -> u32 {
75        debug_assert_eq!(self.words.len(), other.words.len());
76        debug_assert_eq!(self.dim, other.dim);
77        let n_words = self.words.len();
78        if n_words == 0 {
79            return 0;
80        }
81        let mut sum: u32 = 0;
82        for i in 0..n_words - 1 {
83            sum += (!(self.words[i] ^ other.words[i])).count_ones();
84        }
85        // Last word: mask off the padding bits that were never written.
86        let valid_bits = self.dim - 64 * (n_words - 1);
87        let mask: u64 = if valid_bits == 64 {
88            !0u64
89        } else {
90            // Keep the top `valid_bits` MSBs (because we packed MSB-first).
91            !0u64 << (64 - valid_bits)
92        };
93        let last = !(self.words[n_words - 1] ^ other.words[n_words - 1]) & mask;
94        sum += last.count_ones();
95        sum
96    }
97
98    /// **Symmetric** angular estimator (Charikar-style) — both operands are
99    /// 1-bit codes of rotated unit vectors.
100    ///
101    /// For normalized database x̂ (`self.norm` holds the original ‖x‖) and
102    /// normalized query q̂ (`query_code.norm` holds the original ‖q‖):
103    ///
104    ///   E[B/D] = 1 − θ/π  where  θ = arccos(⟨x̂, q̂⟩)
105    ///   ⟹  est cos(θ) = cos(π · (1 − B/D))
106    ///   ⟹  est ⟨q, x⟩ = ‖q‖ · ‖x‖ · est cos(θ)
107    ///
108    /// Returns estimated squared-L2: ‖q − x‖² = ‖q‖² + ‖x‖² − 2⟨q, x⟩.
109    #[inline]
110    pub fn estimated_sq_distance(&self, query_code: &Self) -> f32 {
111        use std::f32::consts::PI;
112        let d = self.dim as f32;
113        let agreement = self.masked_xnor_popcount(query_code) as f32;
114        let est_cos = (PI * (1.0 - agreement / d)).cos();
115        let est_ip = self.norm * query_code.norm * est_cos;
116        let q_sq = query_code.norm * query_code.norm;
117        q_sq + self.norm * self.norm - 2.0 * est_ip
118    }
119
120    /// **Asymmetric** inner-product estimator (RaBitQ-style, keeps the query
121    /// in f32). More accurate than the symmetric path, at the cost of
122    /// O(D) arithmetic per candidate instead of O(D/64) popcount.
123    ///
124    /// Given the rotated-unit query `q_rot` (‖q_rot‖ = 1) and the stored 1-bit
125    /// code `b_x` ∈ {−1/√D, +1/√D}ᴰ, the unbiased inner-product estimate is:
126    ///
127    ///   ⟨q̂_rot, u_x⟩ ≈ (1/√D) · Σᵢ sign(x_rot,i) · q_rot,i
128    ///
129    /// where u_x is the rotated unit vector and `b_x,i = sign(x_rot,i)/√D`.
130    /// The unbiasing factor accounts for the concentration of
131    /// `Σ|q_rot,i|` on a Haar-uniform rotation of q (which preserves norm).
132    ///
133    /// Returns estimated squared-L2: `‖q − x‖² = ‖q‖² + ‖x‖² − 2‖q‖·‖x‖·ŝ`
134    /// where `ŝ = ⟨q̂_rot, u_x⟩` is the unit-sphere IP estimate above.
135    ///
136    /// `q_rotated` must be length `self.dim`; caller pre-normalises and
137    /// pre-rotates the query once per search (amortised across n candidates).
138    #[inline]
139    pub fn estimated_sq_distance_asymmetric(&self, q_rotated_unit: &[f32], q_norm: f32) -> f32 {
140        debug_assert_eq!(q_rotated_unit.len(), self.dim);
141        let d = self.dim;
142        let inv_sqrt_d = 1.0 / (d as f32).sqrt();
143        // Σᵢ sign(x_rot,i) · q_rot,i  without materialising signs: bit = 1
144        // means +1, bit = 0 means −1.
145        let mut ip = 0.0f32;
146        for (i, &q_i) in q_rotated_unit.iter().enumerate() {
147            let bit_set = (self.words[i / 64] >> (63 - (i % 64))) & 1 == 1;
148            ip += if bit_set { q_i } else { -q_i };
149        }
150        let unit_ip = ip * inv_sqrt_d;
151        let est_ip = q_norm * self.norm * unit_ip;
152        q_norm * q_norm + self.norm * self.norm - 2.0 * est_ip
153    }
154}
155
/// Pack a slice of booleans into MSB-first `u64` words (testing/utility helper).
///
/// Element `i` becomes bit `63 - i % 64` of word `i / 64`; the result has
/// `ceil(bits.len() / 64)` words and any unused trailing bits are 0.
pub fn pack_bits(bits: &[bool]) -> Vec<u64> {
    bits.chunks(64)
        .map(|chunk| {
            chunk
                .iter()
                .enumerate()
                .filter(|&(_, &set)| set)
                .fold(0u64, |word, (j, _)| word | (1u64 << (63 - j)))
        })
        .collect()
}

/// Unpack MSB-first `u64` words into a `Vec<bool>` of length `dim`.
///
/// Bit `i` is read from word `i / 64` at position `63 - i % 64`; bits past
/// `dim` are ignored. Panics if `words` holds fewer than `ceil(dim / 64)` words.
pub fn unpack_bits(words: &[u64], dim: usize) -> Vec<bool> {
    let mut out = Vec::with_capacity(dim);
    for i in 0..dim {
        let word = words[i / 64];
        let shift = 63 - (i % 64);
        out.push((word >> shift) & 1 == 1);
    }
    out
}

#[cfg(test)]
mod tests {
    use super::*;

    /// `len` values alternating +1, −1, +1, … (even indices positive).
    fn alternating(len: usize) -> Vec<f32> {
        (0..len).map(|i| if i % 2 == 0 { 1.0 } else { -1.0 }).collect()
    }

    /// Element-wise negation.
    fn negated(v: &[f32]) -> Vec<f32> {
        v.iter().map(|&x| -x).collect()
    }

    /// Euclidean norm, summed left to right.
    fn l2_norm(v: &[f32]) -> f32 {
        v.iter().map(|&x| x * x).sum::<f32>().sqrt()
    }

    /// Normalise `v` to unit length and encode it, storing its original norm.
    fn encode_normalised(v: &[f32]) -> BinaryCode {
        let n = l2_norm(v);
        let unit: Vec<f32> = v.iter().map(|&x| x / n).collect();
        BinaryCode::encode(&unit, n)
    }

    #[test]
    fn pack_unpack_roundtrip() {
        const DIM: usize = 130;
        let original: Vec<bool> = (0..DIM).map(|i| i % 3 == 0).collect();
        let roundtripped = unpack_bits(&pack_bits(&original), DIM);
        assert_eq!(original, roundtripped);
    }

    #[test]
    fn xnor_self_is_all_ones() {
        let code = BinaryCode::encode(&alternating(64), 1.0);
        let agreement = code.xnor_popcount(&code);
        assert_eq!(
            agreement, 64,
            "self-agreement should be D=64, got {agreement}"
        );
    }

    #[test]
    fn xnor_opposite_is_zero() {
        let v = alternating(64);
        let code = BinaryCode::encode(&v, 1.0);
        let code_neg = BinaryCode::encode(&negated(&v), 1.0);
        let agreement = code.xnor_popcount(&code_neg);
        assert_eq!(agreement, 0, "opposite vectors should have 0 agreement");
    }

    /// Regression for the padding bug: at `D % 64 != 0` the padding bits in
    /// the last word are zero in every code, so raw XNOR-popcount counts them
    /// as matches. `masked_xnor_popcount` must not count padding.
    #[test]
    fn masked_popcount_handles_non_aligned_dim() {
        // D=100 → 2 u64 words, 28 padding bits.
        let v = alternating(100);
        let code = BinaryCode::encode(&v, 1.0);
        let code_neg = BinaryCode::encode(&negated(&v), 1.0);

        // Raw reads 0 real matches + 28 padding matches = 28 (wrong).
        let raw = code.xnor_popcount(&code_neg);
        assert_eq!(
            raw, 28,
            "raw xnor should count padding as matches (bug demo)"
        );

        let masked = code.masked_xnor_popcount(&code_neg);
        assert_eq!(
            masked, 0,
            "masked xnor must ignore padding bits; got {masked}"
        );

        // Self-compare: every real bit matches, padding is masked off.
        assert_eq!(code.masked_xnor_popcount(&code), 100);
    }

    #[test]
    fn masked_popcount_matches_raw_when_aligned() {
        // D=128 is 64-aligned, so there is no padding and masked == raw.
        let sines: Vec<f32> = (0..128).map(|i| (i as f32 * 0.1).sin()).collect();
        let cosines: Vec<f32> = (0..128).map(|i| (i as f32 * 0.1).cos()).collect();
        let a = BinaryCode::encode(&sines, 1.0);
        let b = BinaryCode::encode(&cosines, 1.0);
        assert_eq!(a.xnor_popcount(&b), a.masked_xnor_popcount(&b));
    }

    #[test]
    fn estimated_distance_self_is_near_zero() {
        let raw: Vec<f32> = (0..128).map(|i| (i as f32 / 128.0).sin()).collect();
        let norm = l2_norm(&raw);
        let unit: Vec<f32> = raw.iter().map(|&x| x / norm).collect();
        let code = BinaryCode::encode(&unit, 1.0);
        // Same code on both sides: full agreement → cos(0) = 1 → estimate 0.
        let est = code.estimated_sq_distance(&code);
        assert!(
            est.abs() < 1e-5,
            "self sq-distance estimate too large: {est}"
        );
    }

    /// Both estimators carry the same angular signal, so they should rank a
    /// near neighbour ahead of a random vector even if magnitudes differ.
    #[test]
    fn asymmetric_matches_symmetric_in_sign() {
        use rand::{Rng as _, SeedableRng as _};
        const D: usize = 128;
        let mut rng = rand::rngs::StdRng::seed_from_u64(11);

        // RNG draw order matters for reproducibility: q, then near's noise,
        // then far.
        let q: Vec<f32> = (0..D).map(|_| rng.gen::<f32>() * 2.0 - 1.0).collect();
        let q_norm = l2_norm(&q);
        let q_unit: Vec<f32> = q.iter().map(|&x| x / q_norm).collect();
        let qc = BinaryCode::encode(&q_unit, q_norm);

        let near: Vec<f32> = q.iter().map(|&x| x + rng.gen::<f32>() * 0.1).collect();
        let far: Vec<f32> = (0..D).map(|_| rng.gen::<f32>() * 2.0 - 1.0).collect();
        let cn = encode_normalised(&near);
        let cf = encode_normalised(&far);

        let (s_near, s_far) = (cn.estimated_sq_distance(&qc), cf.estimated_sq_distance(&qc));
        // No rotation here: q_unit stands in for the "rotated unit query"
        // because only the estimator arithmetic is under test.
        let (a_near, a_far) = (
            cn.estimated_sq_distance_asymmetric(&q_unit, q_norm),
            cf.estimated_sq_distance_asymmetric(&q_unit, q_norm),
        );
        assert!(s_near < s_far, "symmetric ordering wrong");
        assert!(a_near < a_far, "asymmetric ordering wrong");
    }
}