Skip to main content

rvf_quant/
binary.rs

//! Binary quantization — 32x compression relative to `f32` (1 bit per dimension).
//!
//! Used for the **Cold** tier (Tier 2). Encodes only the sign of each
//! dimension and compares vectors by Hamming distance.
5
6use alloc::vec;
7use alloc::vec::Vec;
8
/// Pack a float vector into a bitmap holding one bit per dimension.
///
/// Dimension `d` is stored in bit `d % 8` (least-significant bit first) of
/// output byte `d / 8`. A bit is set when the value compares `>= 0.0`, so
/// both `0.0` and `-0.0` map to 1 while negative values and NaN map to 0.
///
/// The result has `ceil(vector.len() / 8)` bytes; unused high bits of the
/// final byte are left as 0. An empty input yields an empty `Vec`.
pub fn encode_binary(vector: &[f32]) -> Vec<u8> {
    let mut packed = vec![0u8; vector.len().div_ceil(8)];
    // Each output byte is paired with the (up to) eight dimensions it stores.
    for (byte, group) in packed.iter_mut().zip(vector.chunks(8)) {
        for (bit, &x) in group.iter().enumerate() {
            if x >= 0.0 {
                *byte |= 1u8 << bit;
            }
        }
    }
    packed
}
23
/// Expand a packed bitmap into an approximate float vector of length `dim`.
///
/// Bit `d % 8` of byte `d / 8` selects the value for dimension `d`: a set
/// bit becomes `1.0`, a clear bit becomes `-1.0`. Dimensions whose byte lies
/// past the end of `bits` are also decoded as `-1.0`, so a short buffer
/// never causes a panic.
pub fn decode_binary(bits: &[u8], dim: usize) -> Vec<f32> {
    (0..dim)
        .map(|d| {
            let is_set = bits
                .get(d / 8)
                .is_some_and(|&byte| byte & (1u8 << (d % 8)) != 0);
            if is_set {
                1.0
            } else {
                -1.0
            }
        })
        .collect()
}
40
/// Compute the Hamming distance between two binary-encoded vectors.
///
/// The inputs are walked in 8-byte chunks; each pair of chunks is loaded as
/// a `u64`, XOR-ed, and counted with `count_ones()` (hardware POPCNT where
/// the target supports it). Bytes left over after the last full chunk are
/// compared one at a time.
///
/// Returns the number of bit positions at which `a` and `b` differ.
///
/// # Panics
///
/// Panics if `a` and `b` do not have the same length.
pub fn hamming_distance(a: &[u8], b: &[u8]) -> u32 {
    assert_eq!(a.len(), b.len(), "binary vectors must have equal length");

    // `chunks_exact(8)` guarantees every yielded slice is exactly 8 bytes,
    // so the copy below can never hit a length mismatch.
    let word = |bytes: &[u8]| -> u64 {
        let mut buf = [0u8; 8];
        buf.copy_from_slice(bytes);
        u64::from_le_bytes(buf)
    };

    let chunks_a = a.chunks_exact(8);
    let chunks_b = b.chunks_exact(8);
    // The remainders borrow from the input slices, not from the iterators,
    // so they can be taken before the iterators are consumed.
    let tail_a = chunks_a.remainder();
    let tail_b = chunks_b.remainder();

    // Full 8-byte words: one popcount per word.
    let wide: u32 = chunks_a
        .zip(chunks_b)
        .map(|(wa, wb)| (word(wa) ^ word(wb)).count_ones())
        .sum();

    // Trailing bytes (fewer than 8): one popcount per byte.
    let narrow: u32 = tail_a
        .iter()
        .zip(tail_b)
        .map(|(x, y)| (x ^ y).count_ones())
        .sum();

    wide + narrow
}
86
/// Hamming distance entry point reserved for a SIMD implementation.
///
/// Compiled only when the `simd` feature is enabled. No vectorised kernel
/// exists yet, so the call is forwarded unchanged to [`hamming_distance`]
/// and therefore shares its result and its equal-length panic.
#[cfg(feature = "simd")]
pub fn hamming_distance_simd(a: &[u8], b: &[u8]) -> u32 {
    // Placeholder until a VPOPCNTDQ / CNT kernel is written.
    hamming_distance(a, b)
}
94
#[cfg(test)]
mod tests {
    //! Unit tests for binary encoding, decoding, and Hamming distance.

    use super::*;

    /// Every decoded component must carry the sign class (`>= 0` vs `< 0`)
    /// of the value it was encoded from.
    #[test]
    fn encode_decode_round_trip() {
        let v = vec![1.0, -0.5, 0.3, -2.0, 0.0, 0.1, -0.1, 0.9];
        let bits = encode_binary(&v);
        let decoded = decode_binary(&bits, v.len());

        // Check sign preservation
        for (d, (&orig, &dec)) in v.iter().zip(decoded.iter()).enumerate() {
            if orig >= 0.0 {
                assert_eq!(dec, 1.0, "dim {d}: expected +1 for val {orig}");
            } else {
                assert_eq!(dec, -1.0, "dim {d}: expected -1 for val {orig}");
            }
        }
    }

    /// A code compared against itself has distance zero.
    #[test]
    fn hamming_self_is_zero() {
        let v = vec![1.0, -1.0, 0.5, -0.5, 0.0, 1.0, -1.0, 0.5];
        let bits = encode_binary(&v);
        assert_eq!(hamming_distance(&bits, &bits), 0);
    }

    /// Fully opposite signs differ in every bit.
    #[test]
    fn hamming_opposite_is_max() {
        let v1 = vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
        let v2 = vec![-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0];
        let b1 = encode_binary(&v1);
        let b2 = encode_binary(&v2);
        assert_eq!(hamming_distance(&b1, &b2), 8);
    }

    /// The packed result agrees with a per-dimension sign comparison.
    #[test]
    fn hamming_matches_naive() {
        let v1 = vec![
            1.0, -1.0, 0.5, -0.5, 0.1, -0.1, 0.9, -0.9, 0.3, -0.3, 0.7, -0.7, 0.2, -0.2, 0.8, -0.8,
        ];
        let v2 = vec![
            -1.0, 1.0, -0.5, 0.5, -0.1, 0.1, -0.9, 0.9, -0.3, 0.3, -0.7, 0.7, -0.2, 0.2, -0.8, 0.8,
        ];
        let b1 = encode_binary(&v1);
        let b2 = encode_binary(&v2);

        // All signs are flipped -> hamming distance = 16
        assert_eq!(hamming_distance(&b1, &b2), 16);

        // Naive computation for verification
        let mut naive_dist = 0u32;
        for (x, y) in v1.iter().zip(v2.iter()) {
            if (*x >= 0.0) != (*y >= 0.0) {
                naive_dist += 1;
            }
        }
        assert_eq!(hamming_distance(&b1, &b2), naive_dist);
    }

    /// Dimensions that do not fill the last byte still encode and decode.
    #[test]
    fn non_multiple_of_8_dimensions() {
        let v = vec![1.0, -1.0, 0.5, -0.5, 0.1]; // 5 dims
        let bits = encode_binary(&v);
        assert_eq!(bits.len(), 1); // ceil(5/8) = 1
        let decoded = decode_binary(&bits, 5);
        assert_eq!(decoded.len(), 5);
        assert_eq!(decoded[0], 1.0);
        assert_eq!(decoded[1], -1.0);
        assert_eq!(decoded[4], 1.0); // 0.1 >= 0
    }

    /// Inputs longer than 8 bytes exercise both the u64 chunk path and the
    /// byte-wise remainder path of `hamming_distance`.
    #[test]
    fn hamming_covers_u64_chunks_and_remainder() {
        // 100 dims -> 13 bytes: one 8-byte chunk plus 5 remainder bytes.
        let v1 = vec![1.0f32; 100];
        let mut v2 = vec![1.0f32; 100];
        // Flip dims at the start and end of the chunk (0, 7, 63) and at the
        // start and end of the remainder (64, 99).
        for &d in &[0usize, 7, 63, 64, 99] {
            v2[d] = -1.0;
        }
        let b1 = encode_binary(&v1);
        let b2 = encode_binary(&v2);
        assert_eq!(b1.len(), 13);
        assert_eq!(hamming_distance(&b1, &b2), 5);

        // Cross-check against a plain per-byte popcount.
        let per_byte: u32 = b1
            .iter()
            .zip(b2.iter())
            .map(|(x, y)| (x ^ y).count_ones())
            .sum();
        assert_eq!(hamming_distance(&b1, &b2), per_byte);
    }

    /// Empty vectors encode to no bytes and have distance zero.
    #[test]
    fn empty_input() {
        let empty: Vec<f32> = Vec::new();
        let bits = encode_binary(&empty);
        assert!(bits.is_empty());
        assert!(decode_binary(&bits, 0).is_empty());
        assert_eq!(hamming_distance(&bits, &bits), 0);
    }

    /// Codes of different byte lengths are rejected with a panic.
    #[test]
    #[should_panic(expected = "binary vectors must have equal length")]
    fn hamming_length_mismatch_panics() {
        let short = encode_binary(&[1.0f32; 8]);
        let long = encode_binary(&[1.0f32; 9]);
        let _ = hamming_distance(&short, &long);
    }

    /// Dimensions beyond the supplied bytes decode as -1.0 instead of panicking.
    #[test]
    fn decode_short_buffer_defaults_to_negative() {
        let decoded = decode_binary(&[0xFF], 10);
        assert_eq!(decoded.len(), 10);
        assert_eq!(decoded[7], 1.0);
        assert_eq!(decoded[8], -1.0);
        assert_eq!(decoded[9], -1.0);
    }
}